87 changes: 87 additions & 0 deletions tests/test_cli.py
@@ -1,3 +1,4 @@
import os
import unittest
from unittest.mock import MagicMock, patch

@@ -83,6 +84,92 @@ def test_cli_invalid_pattern(self):
self.assertNotEqual(result.exit_code, 0)
self.assertIn("Invalid value for '--pattern'", result.output)

@patch("tracestorm.cli.run_load_test")
@patch("tracestorm.cli.os.makedirs")
@patch("tracestorm.cli.datetime")
def test_cli_with_output_dir(
self, mock_datetime, mock_makedirs, mock_run_load_test
):
"""Test CLI with output directory option."""
mock_analyzer = MagicMock()
mock_run_load_test.return_value = ([], mock_analyzer)
mock_datetime.datetime.now.return_value.strftime.return_value = (
"20240101_120000"
)

# Test with explicit output dir
result = self.runner.invoke(
main,
[
"--model",
"gpt-3.5-turbo",
"--output-dir",
"custom_output_dir",
],
)

self.assertEqual(result.exit_code, 0)
mock_makedirs.assert_called_with("custom_output_dir", exist_ok=True)
mock_analyzer.export_json.assert_called_once()

# Reset mocks
mock_makedirs.reset_mock()
mock_analyzer.reset_mock()

# Test with default output dir
result = self.runner.invoke(
main,
[
"--model",
"gpt-3.5-turbo",
],
)

self.assertEqual(result.exit_code, 0)
mock_makedirs.assert_called_with(
os.path.join("tracestorm_results", "20240101_120000"), exist_ok=True
)
mock_analyzer.export_json.assert_called_once()

@patch("tracestorm.cli.run_load_test")
@patch("tracestorm.cli.os.makedirs")
def test_cli_with_plot_option(self, mock_makedirs, mock_run_load_test):
"""Test CLI with plot option."""
mock_analyzer = MagicMock()
mock_run_load_test.return_value = ([], mock_analyzer)

# Test with plot enabled
result = self.runner.invoke(
main,
[
"--model",
"gpt-3.5-turbo",
"--plot",
"--output-dir",
"test_dir",
],
)

self.assertEqual(result.exit_code, 0)
mock_analyzer.plot_cdf.assert_called_once()

# Reset mock
mock_analyzer.reset_mock()

# Test with plot disabled (default)
result = self.runner.invoke(
main,
[
"--model",
"gpt-3.5-turbo",
"--output-dir",
"test_dir",
],
)

self.assertEqual(result.exit_code, 0)
mock_analyzer.plot_cdf.assert_not_called()


if __name__ == "__main__":
unittest.main()
58 changes: 58 additions & 0 deletions tests/test_replayer.py
@@ -126,6 +126,64 @@ async def mock_request_raising(*args, **kwargs):
"The 'error' field should match our simulated exception message.",
)

async def _dummy_coro(self):
# Simple coroutine that does nothing
return

@patch("tracestorm.trace_player.logger")
async def async_test_sender_worker(self, mock_logger):
"""Test the sender_worker method with various queue states."""
player = TracePlayer(
self.name,
self.trace,
self.requests,
self.base_url,
self.api_key,
self.ipc_queue,
)

# Set up the dispatch queue and add one item
player.dispatch_queue = asyncio.Queue()
await player.dispatch_queue.put((1000, {"prompt": "Test prompt"}))

# Mock the request method
player.request = AsyncMock(
return_value={
"result": "mock_result",
"token_count": 10,
"time_records": [time.time()],
"error": None,
}
)

# Start the sender_worker and let it process the queued item
task = asyncio.create_task(player.sender_worker())
await asyncio.sleep(0.1) # Let it process the queued item

# Queue is now empty; sender_worker should be idling in its poll/sleep loop
await asyncio.sleep(0.2)

# Set shutdown flag to stop the worker gracefully
player.shutdown_flag.set()
await asyncio.sleep(0.2)

# The task should complete since we set the shutdown flag
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

# Check that request was called with our test data
player.request.assert_called_once()
args, _ = player.request.call_args
self.assertEqual(args[0], {"prompt": "Test prompt"})

def test_sender_worker(self):
"""Test wrapper for async_test_sender_worker."""
asyncio.run(self.async_test_sender_worker())


if __name__ == "__main__":
unittest.main()
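
The new test drives the coroutine through a synchronous asyncio.run wrapper (test_sender_worker above). For comparison, a minimal sketch of the same shape using the standard library's unittest.IsolatedAsyncioTestCase (Python 3.8+), which removes the need for a wrapper; the SenderWorkerTest class and its queue/mock below are illustrative only, not the project's test:

    import asyncio
    import unittest
    from unittest.mock import AsyncMock


    class SenderWorkerTest(unittest.IsolatedAsyncioTestCase):
        async def test_sender_worker(self):
            # Hypothetical stand-in for the sender_worker consumer loop.
            queue: asyncio.Queue = asyncio.Queue()
            await queue.put((1000, {"prompt": "Test prompt"}))
            request = AsyncMock(return_value={"result": "ok"})

            while not queue.empty():
                _, data = queue.get_nowait()
                await request(data)

            request.assert_awaited_once_with({"prompt": "Test prompt"})


    if __name__ == "__main__":
        unittest.main()
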
44 changes: 43 additions & 1 deletion tracestorm/cli.py
@@ -1,3 +1,4 @@
import datetime
import os
from typing import Optional, Tuple

@@ -112,6 +113,23 @@ def create_trace_generator(
@click.option(
"--datasets-config", default=None, help="Config file for datasets"
)
@click.option(
"--plot",
is_flag=True,
default=False,
help="Generate performance plots",
)
@click.option(
"--output-dir",
default=None,
help="Directory to save results (defaults to tracestorm_results/{timestamp})",
)
@click.option(
"--include-raw-results",
is_flag=True,
default=False,
help="Include raw results in the output",
)
def main(
model,
rps,
@@ -122,9 +140,20 @@ def main(
base_url,
api_key,
datasets_config,
plot,
output_dir,
include_raw_results,
):
"""Run trace-based load testing for OpenAI API endpoints."""
try:
# Set up output directory
if output_dir is None:
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join("tracestorm_results", timestamp)

os.makedirs(output_dir, exist_ok=True)
logger.info(f"Results will be saved to: {output_dir}")

trace_generator, warning_msg = create_trace_generator(
pattern, rps, duration, seed
)
@@ -149,7 +178,20 @@ def main(
)

print(result_analyzer)
result_analyzer.plot_cdf()

# Save raw results (always)
results_file = os.path.join(output_dir, "results.json")
result_analyzer.export_json(
results_file, include_raw=include_raw_results
)
logger.info(f"Raw results saved to: {results_file}")

# Only generate plots if requested
if plot:
ttft_file = os.path.join(output_dir, "ttft_cdf.png")
tpot_file = os.path.join(output_dir, "tpot_cdf.png")
result_analyzer.plot_cdf(ttft_file=ttft_file, tpot_file=tpot_file)
logger.info("Performance plots generated")

except ValueError as e:
logger.error(f"Configuration error: {str(e)}")
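
As a reader aid, a minimal, self-contained sketch of the control flow the cli.py hunks above add; the helper names resolve_output_dir and save_results are hypothetical, and only export_json/plot_cdf and their arguments come from the diff:

    import datetime
    import os
    from typing import Optional


    def resolve_output_dir(output_dir: Optional[str]) -> str:
        """Default to tracestorm_results/{timestamp} when no directory is given."""
        if output_dir is None:
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            output_dir = os.path.join("tracestorm_results", timestamp)
        os.makedirs(output_dir, exist_ok=True)
        return output_dir


    def save_results(analyzer, output_dir: Optional[str], plot: bool, include_raw: bool) -> str:
        """Always export JSON; generate CDF plots only when --plot is passed."""
        out = resolve_output_dir(output_dir)
        analyzer.export_json(os.path.join(out, "results.json"), include_raw=include_raw)
        if plot:
            analyzer.plot_cdf(
                ttft_file=os.path.join(out, "ttft_cdf.png"),
                tpot_file=os.path.join(out, "tpot_cdf.png"),
            )
        return out
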
13 changes: 13 additions & 0 deletions tracestorm/core.py
@@ -1,4 +1,5 @@
import multiprocessing
from queue import Empty
from typing import List, Optional, Tuple

from tracestorm.logger import init_logger
@@ -83,6 +84,18 @@ def run_load_test(
f"Received result from {name} for timestamp {timestamp}: {resp['token_count']} tokens"
)
aggregated_results.append((name, timestamp, resp))
except Empty:
# Timeout expired with nothing to read; the worker processes may still be running.
logger.warning(
"No results received from IPC queue in the last 30 seconds. Waiting..."
)
# Stop collecting only if no worker process is still alive
if not any(p.is_alive() for p in processes):
logger.warning(
"All processes seem to have finished. Stopping result collection."
)
break
continue
except Exception as e:
logger.error(f"Error collecting results: {str(e)}", exc_info=True)
break
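
For reference, a standalone sketch of the result-collection pattern this core.py hunk implements — poll the IPC queue with a timeout and stop once no worker process is alive; the collect_results name and the 30-second default here are illustrative, not the project's API:

    from queue import Empty


    def collect_results(ipc_queue, processes, timeout=30):
        """Drain results from the IPC queue until every worker process has exited."""
        results = []
        while True:
            try:
                item = ipc_queue.get(timeout=timeout)
            except Empty:
                # Nothing arrived within the timeout; keep waiting unless all workers are gone.
                if not any(p.is_alive() for p in processes):
                    break
                continue
            results.append(item)
        return results
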
15 changes: 9 additions & 6 deletions tracestorm/trace_player.py
@@ -115,12 +115,12 @@ async def sender_worker(self) -> None:
"""
while not (self.shutdown_flag.is_set() and self.dispatch_queue.empty()):
try:
# Attempt to get a queued item; if none are available quickly, re-check shutdown_flag.
item = await asyncio.wait_for(
self.dispatch_queue.get(), timeout=0.1
)
except asyncio.TimeoutError:
continue # No tasks currently, keep looping.
item = self.dispatch_queue.get_nowait()
except asyncio.QueueEmpty:
if self.shutdown_flag.is_set():
break
await asyncio.sleep(0.1)
continue

timestamp, request_data = item
logger.info(
@@ -152,6 +152,9 @@ async def schedule_requests(self) -> None:
await asyncio.sleep(delay)

request_data = self.requests[i]
# Wait if queue is full
while self.dispatch_queue.full():
await asyncio.sleep(0.1)
# We put both the scheduled timestamp and the request data into the queue.
await self.dispatch_queue.put((scheduled_time, request_data))

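
A small, runnable sketch of the producer/consumer pattern the trace_player.py hunks switch to — a non-blocking get_nowait poll on the consumer side and a back-off while the bounded queue is full on the producer side; the producer/consumer/main names below are illustrative only:

    import asyncio


    async def consumer(queue: asyncio.Queue, shutdown: asyncio.Event) -> None:
        """Poll the queue without blocking so the shutdown flag is re-checked promptly."""
        while not (shutdown.is_set() and queue.empty()):
            try:
                item = queue.get_nowait()
            except asyncio.QueueEmpty:
                if shutdown.is_set():
                    break
                await asyncio.sleep(0.1)
                continue
            print("processing", item)


    async def producer(queue: asyncio.Queue, items, shutdown: asyncio.Event) -> None:
        """Apply backpressure: wait while the bounded queue is full."""
        for item in items:
            while queue.full():
                await asyncio.sleep(0.1)
            await queue.put(item)
        shutdown.set()


    async def main() -> None:
        queue: asyncio.Queue = asyncio.Queue(maxsize=2)
        shutdown = asyncio.Event()
        await asyncio.gather(
            producer(queue, range(5), shutdown),
            consumer(queue, shutdown),
        )


    if __name__ == "__main__":
        asyncio.run(main())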