diff --git a/oracletrace/cli.py b/oracletrace/cli.py
index a531166..cb21611 100644
--- a/oracletrace/cli.py
+++ b/oracletrace/cli.py
@@ -4,12 +4,11 @@
 import json
 import runpy
 import csv
-from .tracer import Tracer, TracerData
+from .tracer import Tracer, TracerData, TracerMetadata, FunctionData
 from .compare import compare_traces, ComparisonData
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Optional
 from re import Pattern
 from argparse import ArgumentParser, Namespace
-from pathlib import Path
 from dataclasses import asdict
 
 
@@ -48,6 +47,15 @@ def main() -> int:
         action="store_true",
         help="Hide functions which didn't run slower than baseline. Use with --compare"
     )
+
+    parser.add_argument(
+        "--repeat",
+        metavar="NUMBER",
+        type=int,
+        default=1,
+        help="Number of times to run the trace against the previous trace JSON",
+    )
+
     args: Namespace = parser.parse_args()
 
     target: str = args.target
@@ -82,26 +90,87 @@ def main() -> int:
             print(f"Regex error: {pattern} -> {e}", file=sys.stderr)
             return 1
 
-    # Start tracing, run the script, then stop
-    tracer: Tracer = Tracer(root, ignore_patterns=ignore_patterns)
-    tracer.start()
-    try:
-        runpy.run_path(target, run_name="__main__")
-    finally:
-        tracer.stop()
+    runs: int = args.repeat
+    if runs < 1:
+        print(f"--repeat must be at least 1, got {runs}", file=sys.stderr)
+        return 1
+
+    def run_trace() -> tuple[Tracer, TracerData]:
+        """Trace one execution of the target script.
+
+        Returns the tracer used for the run together with its collected data.
+        """
+        run_tracer: Tracer = Tracer(root, ignore_patterns=ignore_patterns)
+        run_tracer.start()
+        try:
+            runpy.run_path(target, run_name="__main__")
+        finally:
+            # Always stop tracing, even if the target script raises.
+            run_tracer.stop()
+        return run_tracer, run_tracer.get_trace_data()
 
-    data: TracerData = tracer.get_trace_data()
+    # The first (and, without --repeat, only) traced run.
+    tracer, data = run_trace()
+
+    if runs > 1:
+        # The run above already counts as the first repetition, so only
+        # ``runs - 1`` further executions are needed.
+        all_runs: List[TracerData] = [data]
+        for _ in range(runs - 1):
+            tracer, repeat_data = run_trace()
+            all_runs.append(repeat_data)
+
+        # Sum each function's stats over every run, then convert the sums to
+        # per-run means so the result is comparable with a single-run trace.
+        aggregated: Dict[str, FunctionData] = {}
+        for run_data in all_runs:
+            for function_data in run_data.functions:
+                known = aggregated.get(function_data.name)
+                if known is None:
+                    # Copy so that accumulating never mutates a run's own data.
+                    aggregated[function_data.name] = FunctionData(
+                        name=function_data.name,
+                        total_time=function_data.total_time,
+                        call_count=function_data.call_count,
+                        avg_time=function_data.avg_time,
+                        callees=set(function_data.callees),
+                    )
+                else:
+                    known.add(function_data)
+        for function_data in aggregated.values():
+            function_data.average_over(runs)
+
+        mean_execution_time: float = sum(
+            run_data.metadata.total_execution_time for run_data in all_runs
+        ) / runs
+        data = TracerData(
+            metadata=TracerMetadata(
+                root_path=data.metadata.root_path,
+                total_functions=len(aggregated),
+                total_execution_time=mean_execution_time,
+            ),
+            functions=list(aggregated.values()),
+        )
+
+    def set_default(obj):
+        """JSON fallback encoder: serialize sets (the callees) as sorted lists."""
+        if isinstance(obj, set):
+            # Sorted so the JSON output is deterministic between runs.
+            return sorted(obj)
+        raise TypeError(
+            f"Object of type {type(obj).__name__} is not JSON serializable"
+        )
 
     # Save json
     if args.json:
         with open(args.json, "w", encoding="utf-8") as f:
-            json.dump(asdict(data), f, indent=4)
+            json.dump(asdict(data), f, indent=4, default=set_default)
 
     # Display the analysis
-    if args.top is not None:
-        tracer.show_results(args.top)
-    else:
-        tracer.show_results(None)
+    # The tracer only holds the most recent run, so its report would not
+    # reflect the averaged data and is shown for single runs only.
+    if runs == 1:
+        tracer.show_results(int(args.top) if args.top is not None else None)
 
     # Export as csv
     if args.csv:
@@ -135,7 +204,7 @@ def main() -> int:
             file=sys.stderr,
         )
         return 2
-    
+
     return 0
 
 
diff --git a/oracletrace/tracer.py b/oracletrace/tracer.py
index 31fcbe1..f44f621 100644
--- a/oracletrace/tracer.py
+++ b/oracletrace/tracer.py
@@ -5,11 +5,11 @@
 from rich.tree import Tree
 from rich.table import Table
 from rich import print
 from typing import List, Optional, Callable, DefaultDict, Any, Tuple, Dict
 from re import Pattern
-from pathlib import Path
 from types import FrameType
-from dataclasses import dataclass
+from dataclasses import dataclass, field
+
 
 @dataclass
 class TracerMetadata:
@@ -23,7 +23,37 @@ class FunctionData:
     total_time: float
     call_count: int
     avg_time: float
-    callees: List[str]
+    callees: set[str] = field(default_factory=set)
+
+    def add(self, trace: "FunctionData") -> None:
+        """Accumulate another run's measurements of the same function.
+
+        Times and call counts are summed rather than pairwise-averaged, which
+        would weight later runs more heavily than earlier ones. Call
+        ``average_over`` once every run has been added to turn the sums into
+        per-run means. A ``trace`` of a different function is ignored.
+        """
+        if trace.name != self.name:
+            return
+
+        # Rebuild as a set so list-valued callees are merged correctly too.
+        self.callees = set(self.callees) | set(trace.callees)
+        self.total_time += trace.total_time
+        self.call_count += trace.call_count
+        # Per-call mean over every call recorded so far, across all runs.
+        self.avg_time = self.total_time / self.call_count if self.call_count else 0
+
+    def average_over(self, runs: int) -> None:
+        """Turn the sums accumulated by ``add`` over ``runs`` runs into means.
+
+        Raises:
+            ValueError: If ``runs`` is smaller than 1.
+        """
+        if runs < 1:
+            raise ValueError(f"runs must be at least 1, got {runs}")
+        self.total_time /= runs
+        # avg_time is already a per-call mean and needs no scaling.
+        self.call_count = round(self.call_count / runs)
 
 @dataclass
 class TracerData:
@@ -188,13 +218,13 @@ def add_nodes(parent_node: Tree, parent_key: str, current_path: set[str]) -> Non
             )
 
             for child_key, count in sorted_children:
-                total_time = self._func_time[child_key]
+                _total_time = self._func_time[child_key]
                 # Detect recursion to prevent infinite loops in the tree
                 if child_key in current_path:
                     parent_node.add(f"[red]↻ {child_key}[/] ({count}x)")
                     continue
 
-                node_text = f"{child_key} [dim]({count}x, {total_time:.4f}s)[/]"
+                node_text = f"{child_key} [dim]({count}x, {_total_time:.4f}s)[/]"
                 child_node = parent_node.add(node_text)
                 add_nodes(child_node, child_key, current_path | {child_key})
 
@@ -207,14 +237,13 @@ def get_trace_data(self) -> TracerData:
         for key, total_time in self._func_time.items():
             calls = self._func_calls[key]
             avg_time = total_time / calls if calls else 0
-
             functions.append(
                 FunctionData(
                     name = key,
                     total_time = total_time,
                     call_count = calls,
                     avg_time = avg_time,
-                    callees = list(self._call_map.get(key, {}).keys()),
+                    callees = set(self._call_map.get(key, {})),
                 )
             )
 
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 6f2532c..47e9a30 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -16,6 +16,13 @@
 assert str(REPO_ROOT / "oracletrace") in str(Path(cli.__file__).resolve())
 
 
+def set_default(obj):
+    """JSON fallback mirroring the CLI's: serialize sets as sorted lists."""
+    if isinstance(obj, set):
+        return sorted(obj)
+    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
+
+
 @pytest.fixture
 def trace_data() -> TracerData:
     return TracerData(
@@ -56,14 +63,14 @@ def baseline_trace_data() -> TracerData:
                 total_time = 1.5,
                 call_count = 3,
                 avg_time = 0.5,
-                callees=[]
+                callees=set()
             ),
             FunctionData(
                 name = "bar",
                 total_time = 2.0,
                 call_count = 2,
                 avg_time = 1.0,
-                callees=[]
+                callees=set()
             )
         ]
     )
@@ -340,7 +347,7 @@ def test_main_fails_with_exit_2_on_regression(monkeypatch, tmp_path, empty_trace
     target = tmp_path / "target.py"
     target.write_text("print('hello')\n", encoding="utf-8")
     compare_file = tmp_path / "baseline.json"
-    compare_file.write_text(json.dumps(asdict(empty_trace_data)), encoding="utf-8")
+    compare_file.write_text(json.dumps(asdict(empty_trace_data), default=set_default), encoding="utf-8")
 
     monkeypatch.setattr(cli, "Tracer", lambda root, ignore_patterns: FakeTracer(root, ignore_patterns, empty_trace_data))
     monkeypatch.setattr(cli.runpy, "run_path", lambda *args, **kwargs: None)
@@ -376,7 +383,7 @@ def test_main_returns_0_when_no_regression(monkeypatch, tmp_path, trace_data):
     target = tmp_path / "target.py"
     target.write_text("print('hello')\n", encoding="utf-8")
     compare_file = tmp_path / "baseline.json"
-    compare_file.write_text(json.dumps(asdict(trace_data)), encoding="utf-8")
+    compare_file.write_text(json.dumps(asdict(trace_data), default=set_default), encoding="utf-8")
 
     monkeypatch.setattr(cli, "Tracer", lambda root, ignore_patterns: FakeTracer(root, ignore_patterns, trace_data))
     monkeypatch.setattr(cli.runpy, "run_path", lambda *args, **kwargs: None)
@@ -406,7 +413,7 @@ def test_main_shows_only_regressions(monkeypatch, tmp_path, trace_data, baseline
     target = tmp_path / "target.py"
     target.write_text("print('hello')\n", encoding="utf-8")
     compare_file = tmp_path / "baseline.json"
-    compare_file.write_text(json.dumps(asdict(baseline_trace_data)), encoding="utf-8")
+    compare_file.write_text(json.dumps(asdict(baseline_trace_data), default=set_default), encoding="utf-8")
 
     monkeypatch.setattr(cli, "Tracer", lambda root, ignore_patterns: FakeTracer(root, ignore_patterns, trace_data))
     monkeypatch.setattr(cli.runpy, "run_path", lambda *args, **kwargs: None)