Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 63 additions & 17 deletions oracletrace/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import json
import runpy
import csv
from .tracer import Tracer, TracerData
from collections import defaultdict

from .tracer import Tracer, TracerData, TracerMetadata, FunctionData
from .compare import compare_traces, ComparisonData
from typing import List, Dict, Any, Optional
from typing import List, Dict, Optional
from re import Pattern
from argparse import ArgumentParser, Namespace
from pathlib import Path
from dataclasses import asdict


Expand Down Expand Up @@ -48,6 +49,14 @@ def main() -> int:
action="store_true",
help="Hide functions which didn't run slower than baseline. Use with --compare"
)

parser.add_argument(
"--repeat",
metavar="NUMBER",
help="Number of times to run the trace against the previous trace JSON",
default=1
)

args: Namespace = parser.parse_args()

target: str = args.target
Expand Down Expand Up @@ -82,26 +91,63 @@ def main() -> int:
print(f"Regex error: {pattern} -> {e}", file=sys.stderr)
return 1

# Start tracing, run the script, then stop
tracer: Tracer = Tracer(root, ignore_patterns=ignore_patterns)
tracer.start()
try:
runpy.run_path(target, run_name="__main__")
finally:
tracer.stop()
def run_trace():
_tracer: Tracer = Tracer(root, ignore_patterns=ignore_patterns)

data: TracerData = tracer.get_trace_data()
_tracer.start()
try:
runpy.run_path(target, run_name="__main__")
finally:
_tracer.stop()

_data: TracerData = _tracer.get_trace_data()

return _tracer, _data

tracer, data = run_trace()

runs: int = int(args.repeat)
if runs > 1:
total_time: float = 0
tracer_function_aggs: Dict[str, FunctionData] = {}
for _ in range(runs):
# Start tracing, run the script, then stop
tracer, data = run_trace()

for function_data in data.functions:
if tracer_function_aggs.get(function_data.name) is None:
tracer_function_aggs[function_data.name] = function_data
else:
tracer_function_aggs[function_data.name].add(function_data)
total_time += function_data.total_time

tracer_agg: TracerData = TracerData(
metadata=TracerMetadata(
root_path=data.metadata.root_path,
total_functions=len(tracer_function_aggs),
total_execution_time=total_time
),
functions=list(tracer_function_aggs.values())
)

data = tracer_agg

def set_default(obj):
if isinstance(obj, set):
return list(obj)
raise TypeError

# Save json
if args.json:
with open(args.json, "w", encoding="utf-8") as f:
json.dump(asdict(data), f, indent=4)
json.dump(asdict(data), f, indent=4, default=set_default)

# Display the analysis
if args.top is not None:
tracer.show_results(args.top)
else:
tracer.show_results(None)
if runs <= 1:
if args.top:
tracer.show_results(int(args.top))
else:
tracer.show_results(None)

# Export as csv
if args.csv:
Expand Down Expand Up @@ -135,7 +181,7 @@ def main() -> int:
file=sys.stderr,
)
return 2


return 0

Expand Down
24 changes: 16 additions & 8 deletions oracletrace/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from rich.tree import Tree
from rich.table import Table
from rich import print
from typing import List, Optional, Callable, DefaultDict, Any, Tuple, Dict
from typing import List, Optional, Callable, DefaultDict, Any, Tuple, Dict, Self
from re import Pattern
from pathlib import Path
from types import FrameType
from dataclasses import dataclass
from dataclasses import dataclass, field


@dataclass
class TracerMetadata:
Expand All @@ -23,7 +23,16 @@ class FunctionData:
total_time: float
call_count: int
avg_time: float
callees: List[str]
callees: set[str] = field(default_factory=set)

def add(self, trace: type[Self]) -> None:
if trace.name != self.name:
return

self.callees.update(trace.callees)
self.total_time = (self.total_time + trace.total_time) / 2
self.call_count = (self.call_count + trace.call_count) // 2
self.avg_time = (self.avg_time + trace.avg_time) / 2

@dataclass
class TracerData:
Expand Down Expand Up @@ -188,13 +197,13 @@ def add_nodes(parent_node: Tree, parent_key: str, current_path: set[str]) -> Non
)

for child_key, count in sorted_children:
total_time = self._func_time[child_key]
_total_time = self._func_time[child_key]
# Detect recursion to prevent infinite loops in the tree
if child_key in current_path:
parent_node.add(f"[red]↻ {child_key}[/] ({count}x)")
continue

node_text = f"{child_key} [dim]({count}x, {total_time:.4f}s)[/]"
node_text = f"{child_key} [dim]({count}x, {_total_time:.4f}s)[/]"
child_node = parent_node.add(node_text)
add_nodes(child_node, child_key, current_path | {child_key})

Expand All @@ -207,14 +216,13 @@ def get_trace_data(self) -> TracerData:
for key, total_time in self._func_time.items():
calls = self._func_calls[key]
avg_time = total_time / calls if calls else 0

functions.append(
FunctionData(
name = key,
total_time = total_time,
call_count = calls,
avg_time = avg_time,
callees = list(self._call_map.get(key, {}).keys()),
callees = {k for k in self._call_map.get(key, {}).keys()},
)
)

Expand Down
17 changes: 11 additions & 6 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import importlib
import sys
from pathlib import Path
from oracletrace.tracer import TracerData, FunctionData, TracerMetadata
from oracletrace.tracer import TracerData, FunctionData, TracerMetadata, FunctionData
from oracletrace.compare import ComparisonData
from dataclasses import asdict
import pytest
Expand All @@ -16,6 +16,11 @@
assert str(REPO_ROOT / "oracletrace") in str(Path(cli.__file__).resolve())


def set_default(obj):
if isinstance(obj, set):
return list(obj)
raise TypeError

@pytest.fixture
def trace_data() -> TracerData:
return TracerData(
Expand Down Expand Up @@ -56,14 +61,14 @@ def baseline_trace_data() -> TracerData:
total_time = 1.5,
call_count = 3,
avg_time = 0.5,
callees=[]
callees=set()
),
FunctionData(
name = "bar",
total_time = 2.0,
call_count = 2,
avg_time = 1.0,
callees=[]
callees=set()
)
]
)
Expand Down Expand Up @@ -340,7 +345,7 @@ def test_main_fails_with_exit_2_on_regression(monkeypatch, tmp_path, empty_trace
target = tmp_path / "target.py"
target.write_text("print('hello')\n", encoding="utf-8")
compare_file = tmp_path / "baseline.json"
compare_file.write_text(json.dumps(asdict(empty_trace_data)), encoding="utf-8")
compare_file.write_text(json.dumps(asdict(empty_trace_data), default=set_default), encoding="utf-8")

monkeypatch.setattr(cli, "Tracer", lambda root, ignore_patterns: FakeTracer(root, ignore_patterns, empty_trace_data))
monkeypatch.setattr(cli.runpy, "run_path", lambda *args, **kwargs: None)
Expand Down Expand Up @@ -376,7 +381,7 @@ def test_main_returns_0_when_no_regression(monkeypatch, tmp_path, trace_data):
target = tmp_path / "target.py"
target.write_text("print('hello')\n", encoding="utf-8")
compare_file = tmp_path / "baseline.json"
compare_file.write_text(json.dumps(asdict(trace_data)), encoding="utf-8")
compare_file.write_text(json.dumps(asdict(trace_data), default=set_default), encoding="utf-8")

monkeypatch.setattr(cli, "Tracer", lambda root, ignore_patterns: FakeTracer(root, ignore_patterns, trace_data))
monkeypatch.setattr(cli.runpy, "run_path", lambda *args, **kwargs: None)
Expand Down Expand Up @@ -406,7 +411,7 @@ def test_main_shows_only_regressions(monkeypatch, tmp_path, trace_data, baseline
target = tmp_path / "target.py"
target.write_text("print('hello')\n", encoding="utf-8")
compare_file = tmp_path / "baseline.json"
compare_file.write_text(json.dumps(asdict(baseline_trace_data)), encoding="utf-8")
compare_file.write_text(json.dumps(asdict(baseline_trace_data), default=set_default), encoding="utf-8")

monkeypatch.setattr(cli, "Tracer", lambda root, ignore_patterns: FakeTracer(root, ignore_patterns, trace_data))
monkeypatch.setattr(cli.runpy, "run_path", lambda *args, **kwargs: None)
Expand Down
Loading