diff --git a/tests/test_cli.py b/tests/test_cli.py
index 8ed12a2..09280ab 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,3 +1,4 @@
+import os
 import unittest
 from unittest.mock import MagicMock, patch
 
@@ -83,6 +84,92 @@ def test_cli_invalid_pattern(self):
         self.assertNotEqual(result.exit_code, 0)
         self.assertIn("Invalid value for '--pattern'", result.output)
 
+    @patch("tracestorm.cli.run_load_test")
+    @patch("tracestorm.cli.os.makedirs")
+    @patch("tracestorm.cli.datetime")
+    def test_cli_with_output_dir(
+        self, mock_datetime, mock_makedirs, mock_run_load_test
+    ):
+        """Test CLI with output directory option."""
+        mock_analyzer = MagicMock()
+        mock_run_load_test.return_value = ([], mock_analyzer)
+        mock_datetime.datetime.now.return_value.strftime.return_value = (
+            "20240101_120000"
+        )
+
+        # Test with explicit output dir
+        result = self.runner.invoke(
+            main,
+            [
+                "--model",
+                "gpt-3.5-turbo",
+                "--output-dir",
+                "custom_output_dir",
+            ],
+        )
+
+        self.assertEqual(result.exit_code, 0)
+        mock_makedirs.assert_called_with("custom_output_dir", exist_ok=True)
+        mock_analyzer.export_json.assert_called_once()
+
+        # Reset mocks
+        mock_makedirs.reset_mock()
+        mock_analyzer.reset_mock()
+
+        # Test with default output dir
+        result = self.runner.invoke(
+            main,
+            [
+                "--model",
+                "gpt-3.5-turbo",
+            ],
+        )
+
+        self.assertEqual(result.exit_code, 0)
+        mock_makedirs.assert_called_with(
+            os.path.join("tracestorm_results", "20240101_120000"), exist_ok=True
+        )
+        mock_analyzer.export_json.assert_called_once()
+
+    @patch("tracestorm.cli.run_load_test")
+    @patch("tracestorm.cli.os.makedirs")
+    def test_cli_with_plot_option(self, mock_makedirs, mock_run_load_test):
+        """Test CLI with plot option."""
+        mock_analyzer = MagicMock()
+        mock_run_load_test.return_value = ([], mock_analyzer)
+
+        # Test with plot enabled
+        result = self.runner.invoke(
+            main,
+            [
+                "--model",
+                "gpt-3.5-turbo",
+                "--plot",
+                "--output-dir",
+                "test_dir",
+            ],
+        )
+
+        self.assertEqual(result.exit_code, 0)
+        mock_analyzer.plot_cdf.assert_called_once()
+
+        # Reset mock
+        mock_analyzer.reset_mock()
+
+        # Test with plot disabled (default)
+        result = self.runner.invoke(
+            main,
+            [
+                "--model",
+                "gpt-3.5-turbo",
+                "--output-dir",
+                "test_dir",
+            ],
+        )
+
+        self.assertEqual(result.exit_code, 0)
+        mock_analyzer.plot_cdf.assert_not_called()
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_replayer.py b/tests/test_replayer.py
index aaeabc8..f125fee 100644
--- a/tests/test_replayer.py
+++ b/tests/test_replayer.py
@@ -126,6 +126,64 @@ async def mock_request_raising(*args, **kwargs):
             "The 'error' field should match our simulated exception message.",
         )
 
+    async def _dummy_coro(self):
+        # Simple coroutine that does nothing
+        return
+
+    @patch("tracestorm.trace_player.logger")
+    async def async_test_sender_worker(self, mock_logger):
+        """Test the sender_worker method with various queue states."""
+        player = TracePlayer(
+            self.name,
+            self.trace,
+            self.requests,
+            self.base_url,
+            self.api_key,
+            self.ipc_queue,
+        )
+
+        # Set up the dispatch queue and add one item
+        player.dispatch_queue = asyncio.Queue()
+        await player.dispatch_queue.put((1000, {"prompt": "Test prompt"}))
+
+        # Mock the request method
+        player.request = AsyncMock(
+            return_value={
+                "result": "mock_result",
+                "token_count": 10,
+                "time_records": [time.time()],
+                "error": None,
+            }
+        )
+
+        # Setup - run the sender_worker for a bit then set shutdown flag
+        task = asyncio.create_task(player.sender_worker())
+        await asyncio.sleep(0.1)  # Let it process the queued item
+
+        # Queue is now empty, sender_worker should call sleep
+        await asyncio.sleep(0.2)
+
+        # Set shutdown flag to stop the worker gracefully
+        player.shutdown_flag.set()
+        await asyncio.sleep(0.2)
+
+        # The task should complete since we set the shutdown flag
+        if not task.done():
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass
+
+        # Check that request was called with our test data
+        player.request.assert_called_once()
+        args, _ = player.request.call_args
+        self.assertEqual(args[0], {"prompt": "Test prompt"})
+
+    def test_sender_worker(self):
+        """Test wrapper for async_test_sender_worker."""
+        asyncio.run(self.async_test_sender_worker())
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tracestorm/cli.py b/tracestorm/cli.py
index 2647858..2def5d4 100644
--- a/tracestorm/cli.py
+++ b/tracestorm/cli.py
@@ -1,3 +1,4 @@
+import datetime
 import os
 from typing import Optional, Tuple
 
@@ -112,6 +113,23 @@ def create_trace_generator(
 @click.option(
     "--datasets-config", default=None, help="Config file for datasets"
 )
+@click.option(
+    "--plot",
+    is_flag=True,
+    default=False,
+    help="Generate performance plots",
+)
+@click.option(
+    "--output-dir",
+    default=None,
+    help="Directory to save results (defaults to tracestorm_results/{timestamp})",
+)
+@click.option(
+    "--include-raw-results",
+    is_flag=True,
+    default=False,
+    help="Include raw results in the output",
+)
 def main(
     model,
     rps,
@@ -122,9 +140,20 @@ def main(
     base_url,
     api_key,
     datasets_config,
+    plot,
+    output_dir,
+    include_raw_results,
 ):
     """Run trace-based load testing for OpenAI API endpoints."""
     try:
+        # Set up output directory
+        if output_dir is None:
+            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+            output_dir = os.path.join("tracestorm_results", timestamp)
+
+        os.makedirs(output_dir, exist_ok=True)
+        logger.info(f"Results will be saved to: {output_dir}")
+
         trace_generator, warning_msg = create_trace_generator(
             pattern, rps, duration, seed
         )
@@ -149,7 +178,20 @@ def main(
         )
 
         print(result_analyzer)
-        result_analyzer.plot_cdf()
+
+        # Save raw results (always)
+        results_file = os.path.join(output_dir, "results.json")
+        result_analyzer.export_json(
+            results_file, include_raw=include_raw_results
+        )
+        logger.info(f"Raw results saved to: {results_file}")
+
+        # Only generate plots if requested
+        if plot:
+            ttft_file = os.path.join(output_dir, "ttft_cdf.png")
+            tpot_file = os.path.join(output_dir, "tpot_cdf.png")
+            result_analyzer.plot_cdf(ttft_file=ttft_file, tpot_file=tpot_file)
+            logger.info("Performance plots generated")
 
     except ValueError as e:
         logger.error(f"Configuration error: {str(e)}")
diff --git a/tracestorm/core.py b/tracestorm/core.py
index 81ebef7..4eb81b3 100644
--- a/tracestorm/core.py
+++ b/tracestorm/core.py
@@ -1,4 +1,5 @@
 import multiprocessing
+from queue import Empty
 from typing import List, Optional, Tuple
 
 from tracestorm.logger import init_logger
@@ -83,6 +84,18 @@ def run_load_test(
                 f"Received result from {name} for timestamp {timestamp}: {resp['token_count']} tokens"
             )
             aggregated_results.append((name, timestamp, resp))
+        except Empty:
+            # Timeout occurred, but maybe not all processes are finished.
+            logger.warning(
+                "No results received from IPC queue in the last 30 seconds. Waiting..."
+            )
+            # Check if all processes are still alive before continuing
+            if not any(p.is_alive() for p in processes):
+                logger.warning(
+                    "All processes seem to have finished. Stopping result collection."
+                )
+                break
+            continue
         except Exception as e:
             logger.error(f"Error collecting results: {str(e)}", exc_info=True)
             break
diff --git a/tracestorm/trace_player.py b/tracestorm/trace_player.py
index 3bb6f30..10ee76d 100644
--- a/tracestorm/trace_player.py
+++ b/tracestorm/trace_player.py
@@ -115,12 +115,12 @@ async def sender_worker(self) -> None:
         """
         while not (self.shutdown_flag.is_set() and self.dispatch_queue.empty()):
             try:
-                # Attempt to get a queued item; if none are available quickly, re-check shutdown_flag.
-                item = await asyncio.wait_for(
-                    self.dispatch_queue.get(), timeout=0.1
-                )
-            except asyncio.TimeoutError:
-                continue  # No tasks currently, keep looping.
+                item = self.dispatch_queue.get_nowait()
+            except asyncio.QueueEmpty:
+                if self.shutdown_flag.is_set():
+                    break
+                await asyncio.sleep(0.1)
+                continue
 
             timestamp, request_data = item
             logger.info(
@@ -152,6 +152,9 @@ async def schedule_requests(self) -> None:
             await asyncio.sleep(delay)
 
             request_data = self.requests[i]
+            # Wait if queue is full
+            while self.dispatch_queue.full():
+                await asyncio.sleep(0.1)
             # We put both the scheduled timestamp and the request data into the queue.
             await self.dispatch_queue.put((scheduled_time, request_data))
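
For a quick end-to-end check of the new CLI options, a minimal sketch in the style of the tests above: it drives `main` through click's `CliRunner` with `run_load_test` patched out, so no live endpoint is needed. `demo_out` is just an illustrative directory name.

    from unittest.mock import MagicMock, patch

    from click.testing import CliRunner

    from tracestorm.cli import main

    runner = CliRunner()
    # Patch out the actual load test so the command runs offline.
    with patch("tracestorm.cli.run_load_test") as mock_run:
        mock_run.return_value = ([], MagicMock())  # (results, analyzer)
        result = runner.invoke(
            main,
            ["--model", "gpt-3.5-turbo", "--plot", "--output-dir", "demo_out"],
        )

    assert result.exit_code == 0
    # In a real (unpatched) run, demo_out/ would contain results.json,
    # plus ttft_cdf.png and tpot_cdf.png because --plot was given.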