updated title of plots
Browse files- make_plot.py +51 -42
make_plot.py
CHANGED
@@ -21,37 +21,41 @@ def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
|
|
21 |
fig = go.Figure()
|
22 |
|
23 |
# Add the first scatter plot with steelblue color
|
24 |
-
fig.add_trace(
|
|
|
25 |
x=df1.index,
|
26 |
y=df1.iloc[:, 0],
|
27 |
-
mode=
|
28 |
-
name=
|
29 |
-
line=dict(color=
|
30 |
-
marker=dict(color=
|
31 |
-
|
|
|
32 |
|
33 |
# Add the second scatter plot with yellow color
|
34 |
-
fig.add_trace(
|
|
|
35 |
x=df2.index,
|
36 |
y=df2.iloc[:, 0],
|
37 |
-
mode=
|
38 |
-
name=
|
39 |
-
line=dict(color=
|
40 |
-
marker=dict(color=
|
41 |
-
|
|
|
42 |
|
43 |
# Customize the layout
|
44 |
fig.update_layout(
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
return fig
|
52 |
|
53 |
|
54 |
-
def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame])
|
55 |
"""
|
56 |
Plot the true values and forecasts using Plotly.
|
57 |
|
@@ -67,48 +71,53 @@ def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]) -> go.Figure:
|
|
67 |
fig = go.Figure()
|
68 |
|
69 |
# Add the true values trace
|
70 |
-
fig.add_trace(
|
|
|
71 |
x=pd.to_datetime(df.index),
|
72 |
y=df.iloc[:, 0],
|
73 |
-
mode=
|
74 |
-
name=
|
75 |
-
line=dict(color=
|
76 |
-
|
|
|
77 |
|
78 |
# Add the forecast traces
|
79 |
colors = ["green", "blue", "purple"]
|
80 |
for i, forecast in enumerate(forecasts):
|
81 |
color = colors[i]
|
82 |
for sample in forecast.samples:
|
83 |
-
fig.add_trace(
|
|
|
84 |
x=forecast.index.to_timestamp(),
|
85 |
y=sample,
|
86 |
-
mode=
|
87 |
opacity=0.15, # Adjust opacity to control visibility of individual samples
|
88 |
-
name=f
|
89 |
showlegend=False, # Hide the individual forecast series from the legend
|
90 |
-
hoverinfo=
|
91 |
-
line=dict(color=color)
|
92 |
-
|
|
|
93 |
# Add the average
|
94 |
mean_forecast = np.mean(forecast.samples, axis=0)
|
95 |
-
fig.add_trace(
|
|
|
96 |
x=forecast.index.to_timestamp(),
|
97 |
y=mean_forecast,
|
98 |
-
mode=
|
99 |
-
name=
|
100 |
-
line=dict(color=
|
101 |
-
|
|
|
102 |
|
103 |
# Customize the layout
|
104 |
fig.update_layout(
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
)
|
112 |
|
113 |
# Return the figure
|
114 |
return fig
|
|
|
21 |
fig = go.Figure()
|
22 |
|
23 |
# Add the first scatter plot with steelblue color
|
24 |
+
fig.add_trace(
|
25 |
+
go.Scatter(
|
26 |
x=df1.index,
|
27 |
y=df1.iloc[:, 0],
|
28 |
+
mode="lines",
|
29 |
+
name="Training Data",
|
30 |
+
line=dict(color="steelblue"),
|
31 |
+
marker=dict(color="steelblue"),
|
32 |
+
)
|
33 |
+
)
|
34 |
|
35 |
# Add the second scatter plot with yellow color
|
36 |
+
fig.add_trace(
|
37 |
+
go.Scatter(
|
38 |
x=df2.index,
|
39 |
y=df2.iloc[:, 0],
|
40 |
+
mode="lines",
|
41 |
+
name="Test Data",
|
42 |
+
line=dict(color="gold"),
|
43 |
+
marker=dict(color="gold"),
|
44 |
+
)
|
45 |
+
)
|
46 |
|
47 |
# Customize the layout
|
48 |
fig.update_layout(
|
49 |
+
title="Univariate Time Series",
|
50 |
+
xaxis=dict(title="Date"),
|
51 |
+
yaxis=dict(title="Value"),
|
52 |
+
showlegend=True,
|
53 |
+
template="plotly_white",
|
54 |
+
)
|
55 |
return fig
|
56 |
|
57 |
|
58 |
+
def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]):
|
59 |
"""
|
60 |
Plot the true values and forecasts using Plotly.
|
61 |
|
|
|
71 |
fig = go.Figure()
|
72 |
|
73 |
# Add the true values trace
|
74 |
+
fig.add_trace(
|
75 |
+
go.Scatter(
|
76 |
x=pd.to_datetime(df.index),
|
77 |
y=df.iloc[:, 0],
|
78 |
+
mode="lines",
|
79 |
+
name="True values",
|
80 |
+
line=dict(color="black"),
|
81 |
+
)
|
82 |
+
)
|
83 |
|
84 |
# Add the forecast traces
|
85 |
colors = ["green", "blue", "purple"]
|
86 |
for i, forecast in enumerate(forecasts):
|
87 |
color = colors[i]
|
88 |
for sample in forecast.samples:
|
89 |
+
fig.add_trace(
|
90 |
+
go.Scatter(
|
91 |
x=forecast.index.to_timestamp(),
|
92 |
y=sample,
|
93 |
+
mode="lines",
|
94 |
opacity=0.15, # Adjust opacity to control visibility of individual samples
|
95 |
+
name=f"Forecast {i + 1}",
|
96 |
showlegend=False, # Hide the individual forecast series from the legend
|
97 |
+
hoverinfo="none", # Disable hover information for the forecast series
|
98 |
+
line=dict(color=color),
|
99 |
+
)
|
100 |
+
)
|
101 |
# Add the average
|
102 |
mean_forecast = np.mean(forecast.samples, axis=0)
|
103 |
+
fig.add_trace(
|
104 |
+
go.Scatter(
|
105 |
x=forecast.index.to_timestamp(),
|
106 |
y=mean_forecast,
|
107 |
+
mode="lines",
|
108 |
+
name="Mean Forecast",
|
109 |
+
line=dict(color="red", dash="dash"),
|
110 |
+
)
|
111 |
+
)
|
112 |
|
113 |
# Customize the layout
|
114 |
fig.update_layout(
|
115 |
+
title=f"{df.columns[0]} Forecast",
|
116 |
+
yaxis=dict(title=df.columns[0]),
|
117 |
+
showlegend=True,
|
118 |
+
legend=dict(x=0, y=1, font=dict(size=16)),
|
119 |
+
hovermode="x", # Enable x-axis hover for better interactivity
|
120 |
+
)
|
|
|
121 |
|
122 |
# Return the figure
|
123 |
return fig
|