-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvisualize.py
More file actions
57 lines (51 loc) · 1.97 KB
/
visualize.py
File metadata and controls
57 lines (51 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import plotly.graph_objs as go
def plot_stock(df, ticker):
    """Display an interactive chart of closing prices with optional overlays.

    The x axis is taken from the ``'Date'`` column when ``df`` has one,
    otherwise from ``df.index``. A ``'Close'`` trace is always drawn. An
    ``'SMA_20'`` trace is added when that column exists. An ``'RSI'`` trace
    is added when that column exists and is drawn against a secondary
    right-hand axis fixed to the 0-100 range.

    Args:
        df: DataFrame holding at least a ``'Close'`` column.
        ticker: Symbol shown in the chart title.

    Returns:
        None. The figure is rendered via ``Figure.show()``.
    """
    figure = go.Figure()
    dates = df.index if 'Date' not in df.columns else df['Date']

    figure.add_trace(go.Scatter(x=dates, y=df['Close'], name='Close'))

    # Optional series: (column name, Scatter keyword arguments), in draw order.
    optional_series = (
        ('SMA_20', {'name': 'SMA 20'}),
        ('RSI', {
            'name': 'RSI',
            # RSI is an oscillator, so it gets its own right-hand axis.
            'yaxis': 'y2',
            'line': dict(color='orange', dash='dot'),
        }),
    )
    for column, scatter_kwargs in optional_series:
        if column in df.columns:
            figure.add_trace(go.Scatter(x=dates, y=df[column], **scatter_kwargs))

    rsi_axis = dict(
        title='RSI',
        overlaying='y',
        side='right',
        range=[0, 100],
        showgrid=False,
    )
    figure.update_layout(
        title=f"{ticker} Price, SMA 20, and RSI",
        xaxis=dict(title='Date'),
        yaxis=dict(title='Price'),
        yaxis2=rsi_axis,
        height=600,
    )
    figure.show()
def _forecast_dates(df, periods, freq='B'):
    """Return ``periods`` timestamps that follow the last row of ``df``.

    The last observation comes from the ``'Date'`` column when present,
    otherwise from ``df.index``. The first forecast date is one ``freq``
    step after it, so with the default ``'B'`` a Friday rolls to Monday.
    """
    import pandas as pd
    from pandas.tseries.frequencies import to_offset

    last = df['Date'].iloc[-1] if 'Date' in df.columns else df.index[-1]
    last_date = pd.to_datetime(last)
    # Stepping by the offset itself, rather than a fixed 1-day Timedelta,
    # keeps hourly/monthly data from skipping or repeating periods.
    step = to_offset(freq)
    return pd.date_range(last_date + step, periods=periods, freq=step)


def _ci_bounds(conf_int):
    """Split a two-column confidence interval into (lower, upper) columns."""
    import pandas as pd

    # Wrapping in a DataFrame accepts both DataFrames (e.g. statsmodels
    # conf_int()) and plain 2-D arrays/lists, which lack ``.iloc``.
    bounds = pd.DataFrame(conf_int)
    return bounds.iloc[:, 0], bounds.iloc[:, 1]


def plot_forecast(df, forecast, conf_int, ticker, freq='B'):
    """Display historical closing prices with a forecast and confidence band.

    Args:
        df: DataFrame of history with a ``'Close'`` column. The time axis is
            the ``'Date'`` column when present, otherwise ``df.index``.
        forecast: Predicted values, one per future period (must support
            ``len()``).
        conf_int: DataFrame or 2-D array-like with lower bounds in column 0
            and upper bounds in column 1, one row per forecast period.
        ticker: Symbol shown in the chart title.
        freq: pandas frequency string or DateOffset spacing the forecast
            dates. Defaults to ``'B'`` (business days), the previous
            hard-coded behavior.

    Returns:
        None. The figure is rendered via ``Figure.show()``.
    """
    forecast_dates = _forecast_dates(df, len(forecast), freq)
    lower, upper = _ci_bounds(conf_int)

    fig = go.Figure()
    # Historical close
    x = df['Date'] if 'Date' in df.columns else df.index
    fig.add_trace(go.Scatter(x=x, y=df['Close'], name='Close'))
    # Forecast
    fig.add_trace(go.Scatter(
        x=forecast_dates, y=forecast, name='Forecast',
        line=dict(color='green', dash='dash'),
    ))
    # Confidence interval. The lower bound must be added immediately before
    # the upper bound, because fill='tonexty' shades down to the previous trace.
    fig.add_trace(go.Scatter(
        x=forecast_dates, y=lower, fill=None, mode='lines',
        line=dict(color='lightgreen'), name='Lower CI',
    ))
    fig.add_trace(go.Scatter(
        x=forecast_dates, y=upper, fill='tonexty', mode='lines',
        line=dict(color='lightgreen'), name='Upper CI',
    ))
    fig.update_layout(
        title=f"{ticker} Forecasted Close Prices with Confidence Interval",
        xaxis=dict(title='Date'),
        yaxis=dict(title='Price'),
        height=600,
    )
    fig.show()