Skip to content

Commit 291ba1a

Browse files
committed
Fix flake8 warnings
1 parent f31f48e commit 291ba1a

File tree

6 files changed

+39
-9
lines changed

6 files changed

+39
-9
lines changed

investing_algorithm_framework/app/analysis/backtest_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def save_backtests_to_directory(
1414
backtests: List[Backtest],
1515
directory_path: Union[str, Path],
1616
dir_name_generation_function: Callable[[Backtest], str] = None,
17+
number_of_backtests_to_save: int = None,
1718
filter_function: Callable[[Backtest], bool] = None
1819
) -> None:
1920
"""
@@ -28,6 +29,8 @@ def save_backtests_to_directory(
2829
a string to be used as the directory name for that backtest.
2930
If not provided, the backtest's metadata 'id' will be used.
3031
Defaults to None.
32+
number_of_backtests_to_save (int, optional): Maximum number of
33+
backtests to save. If None, all backtests will be saved.
3134
filter_function (Callable[[Backtest], bool], optional): A function
3235
that takes a Backtest object as input and returns True if the
3336
backtest should be saved. Defaults to None.
@@ -41,6 +44,12 @@ def save_backtests_to_directory(
4144

4245
for backtest in backtests:
4346

47+
# Check if we have reached the limit of backtests to save
48+
if number_of_backtests_to_save is not None:
49+
if number_of_backtests_to_save <= 0:
50+
break
51+
number_of_backtests_to_save -= 1
52+
4453
if filter_function is not None:
4554
if not filter_function(backtest):
4655
continue
@@ -63,7 +72,8 @@ def save_backtests_to_directory(
6372

6473
def load_backtests_from_directory(
6574
directory_path: Union[str, Path],
66-
filter_function: Callable[[Backtest], bool] = None
75+
filter_function: Callable[[Backtest], bool] = None,
76+
number_of_backtests_to_load: int = None
6777
) -> List[Backtest]:
6878
"""
6979
Loads Backtest objects from the specified directory.
@@ -74,6 +84,8 @@ def load_backtests_from_directory(
7484
filter_function (Callable[[Backtest], bool], optional): A function
7585
that takes a Backtest object as input and returns True if the
7686
backtest should be included in the result. Defaults to None.
87+
number_of_backtests_to_load (int, optional): Maximum number of
88+
backtests to load. If None, all backtests will be loaded.
7789
7890
Returns:
7991
List[Backtest]: List of loaded Backtest objects.
@@ -89,6 +101,13 @@ def load_backtests_from_directory(
89101
return backtests
90102

91103
for file_name in os.listdir(directory_path):
104+
105+
# Check if we have reached the limit of backtests to load
106+
if number_of_backtests_to_load is not None:
107+
if number_of_backtests_to_load <= 0:
108+
break
109+
number_of_backtests_to_load -= 1
110+
92111
file_path = os.path.join(directory_path, file_name)
93112

94113
try:

investing_algorithm_framework/services/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
from .repository_service import RepositoryService
1414
from .trade_service import TradeService, TradeStopLossService, \
1515
TradeTakeProfitService
16-
from .metrics import get_annual_volatility, \
16+
from .metrics import get_annual_volatility, get_mean_daily_return, \
1717
get_sortino_ratio, get_drawdown_series, get_max_drawdown, \
1818
get_equity_curve, get_price_efficiency_ratio, get_sharpe_ratio, \
1919
get_profit_factor, get_cumulative_profit_factor_series, \
20-
get_rolling_profit_factor_series, \
20+
get_rolling_profit_factor_series, get_daily_returns_std, \
2121
get_cagr, get_standard_deviation_returns, \
2222
get_standard_deviation_downside_returns, \
2323
get_total_return, get_cumulative_exposure, get_exposure_ratio, \
@@ -41,6 +41,8 @@
4141
get_current_average_trade_gain, create_backtest_metrics_for_backtest
4242

4343
__all__ = [
44+
"get_mean_daily_return",
45+
"get_daily_returns_std",
4446
"OrderService",
4547
"RepositoryService",
4648
"PortfolioService",

investing_algorithm_framework/services/metrics/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,12 @@
3838
get_current_average_trade_duration, get_current_average_trade_gain, \
3939
get_current_average_trade_return, get_number_of_open_trades, \
4040
get_average_trade_duration
41+
from .mean_daily_return import get_mean_daily_return
42+
from .standard_deviation import get_daily_returns_std
4143

4244
__all__ = [
45+
"get_mean_daily_return",
46+
"get_daily_returns_std",
4347
"get_annual_volatility",
4448
"get_sortino_ratio",
4549
"get_drawdown_series",

investing_algorithm_framework/services/metrics/mean_daily_return.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,13 @@ def get_mean_daily_return(snapshots):
3838

3939
# Check if the period is less than a year
4040
if (end_date - start_date).days < 365:
41+
print("Less than a year of data, using CAGR to calculate mean daily return.")
4142
# Use CAGR to calculate mean daily return
4243
cagr = get_cagr(snapshots)
4344
if cagr == 0.0:
4445
return 0.0
46+
47+
print(f"CAGR: {cagr}")
4548
return (1 + cagr) ** (1 / 365) - 1
4649

4750
# Resample to daily frequency using last value of the day

investing_algorithm_framework/services/metrics/sharpe_ratio.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@
4444
4545
"""
4646

47-
from typing import Optional, List, Tuple
48-
4947
import math
50-
import pandas as pd
51-
import numpy as np
5248
from datetime import datetime
49+
from typing import List, Tuple
50+
51+
import numpy as np
52+
import pandas as pd
5353

5454
from investing_algorithm_framework.domain import PortfolioSnapshot
5555
from .mean_daily_return import get_mean_daily_return
@@ -76,7 +76,10 @@ def get_sharpe_ratio(
7676
"""
7777
snapshots = sorted(snapshots, key=lambda s: s.created_at)
7878
mean_daily_return = get_mean_daily_return(snapshots)
79+
80+
print(f"mean daily return {mean_daily_return}")
7981
std_daily_return = get_daily_returns_std(snapshots)
82+
print(f"std daily return {std_daily_return}")
8083

8184
if std_daily_return == 0:
8285
return float('nan') # Avoid division by zero

investing_algorithm_framework/services/metrics/standard_deviation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,8 @@ def get_daily_returns_std(snapshots):
108108
df["created_at"] = pd.to_datetime(df["created_at"])
109109
df = df.drop_duplicates("created_at").set_index("created_at")
110110
df = df.sort_index()
111-
112111
# Resample to daily frequency (end of day)
113-
daily_df = df.resample("D").last().dropna()
112+
daily_df = df.resample("D").last().ffill().dropna()
114113

115114
# Calculate daily returns
116115
daily_df["return"] = daily_df["total_value"].pct_change().dropna()

0 commit comments

Comments
 (0)