Skip to content

Commit 93affcb

Browse files
feat: add a profiled predict mode
1 parent 44782c0 commit 93affcb

4 files changed

Lines changed: 42 additions & 3 deletions

File tree

src/analyzer/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from .data_loader import load_chessfile_predict, load_chessfile_train
22
from .labels import label_to_vector, vector_to_label
3-
from .modes import predict_mode, train_mode
3+
from .modes import predict_mode, predict_mode_profiled, train_mode
44

55
__all__ = [
66
"load_chessfile_predict",
77
"load_chessfile_train",
88
"label_to_vector",
99
"vector_to_label",
1010
"predict_mode",
11+
"predict_mode_profiled",
1112
"train_mode",
1213
]

src/analyzer/modes.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from .data_loader import load_chessfile_predict, load_chessfile_train
88
from .labels import vector_to_label
9+
from .profile import profile_it
910

1011

1112
def process_fen(network: Network, encoding: str, fen: str):
@@ -30,6 +31,18 @@ def predict_mode(
3031
print(result)
3132

3233

34+
@profile_it
35+
def predict_mode_profiled(
36+
network: Network, chessfile: str, encoding: str = "simple"
37+
) -> None:
38+
fens = load_chessfile_predict(chessfile)
39+
40+
dispatch = functools.partial(process_fen, network, encoding)
41+
42+
for c, fen in enumerate(fens):
43+
print(c, dispatch(fen))
44+
45+
3346
def train_mode(
3447
network: Network,
3548
chessfile: str,

src/analyzer/profile.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import cProfile
2+
from pstats import SortKey, Stats
3+
4+
5+
def profile_it(func):
6+
def wrapper(*args, **kwargs):
7+
with cProfile.Profile() as pr:
8+
result = func(*args, **kwargs)
9+
10+
Stats(pr).sort_stats(SortKey.CUMULATIVE).print_stats()
11+
return result
12+
13+
return wrapper

src/my_torch/tools/my_torch_analyzer.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#!/usr/bin/env python3
22
import argparse
33
import sys
4+
import time
45

5-
from analyzer import predict_mode, train_mode
6+
from analyzer import predict_mode, predict_mode_profiled, train_mode
67
from my_torch import Network
78

89

@@ -52,6 +53,10 @@ def main():
5253
help="Board encoding method (default: simple)",
5354
)
5455

56+
parser.add_argument(
57+
"--profiled", action="store_true", help=argparse.SUPPRESS
58+
)
59+
5560
args = parser.parse_args()
5661

5762
try:
@@ -66,7 +71,14 @@ def main():
6671
return 84
6772

6873
if args.predict:
69-
predict_mode(network, args.chessfile, encoding=args.encoding)
74+
if args.profiled:
75+
print("Using profiled version, single threaded.")
76+
time.sleep(1)
77+
predict_mode_profiled(
78+
network, args.chessfile, encoding=args.encoding
79+
)
80+
else:
81+
predict_mode(network, args.chessfile, encoding=args.encoding)
7082
else:
7183
savefile = args.save if args.save else args.loadfile
7284
train_mode(

0 commit comments

Comments
 (0)