# YF_Forecast / app.py
# Last commit by JayLacoma: "Update app.py" (f341752, verified)
import functools

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import timesfm
import yfinance as yf
# Hugging Face checkpoint and forecast length used by the app.
_TIMESFM_REPO_ID = "google/timesfm-2.0-500m-pytorch"
_HORIZON_LEN = 30  # number of future trading-day steps to predict


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the TimesFM model once and reuse it for every request.

    Loading the 500M-parameter checkpoint is the most expensive step of a
    forecast, so it is cached rather than repeated on every button click.
    A failed load raises and is not cached, so the next request retries.
    """
    return timesfm.TimesFm(
        hparams=timesfm.TimesFmHparams(
            backend="pytorch",
            per_core_batch_size=32,
            horizon_len=_HORIZON_LEN,
            input_patch_len=32,
            output_patch_len=128,
            num_layers=50,
            model_dims=1280,
            use_positional_embedding=False,
            # The 2.0-500m checkpoint accepts up to 2048 context points; the
            # hparams default is shorter and would silently drop older history.
            context_len=2048,
        ),
        checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id=_TIMESFM_REPO_ID),
    )


def _prepare_series(stock_data, ticker):
    """Convert a raw ``yf.download`` frame into TimesFM's long format.

    Returns a DataFrame with columns ``ds`` (timestamp), ``y`` (closing price)
    and ``unique_id`` (the ticker). Raises ``ValueError`` when there is nothing
    usable to forecast.
    """
    if stock_data is None or stock_data.empty:
        raise ValueError(f"No price data returned for '{ticker}' in the requested date range.")

    # Recent yfinance versions return (Price, Ticker) MultiIndex columns even
    # for a single ticker; keep only the price-field level.
    if isinstance(stock_data.columns, pd.MultiIndex):
        stock_data = stock_data.copy()
        stock_data.columns = stock_data.columns.get_level_values(0)

    # Select the close by name: relabelling columns positionally breaks as soon
    # as yfinance changes the order or set of columns it returns.
    if "Close" not in stock_data.columns:
        raise ValueError(f"Downloaded data has no 'Close' column (got {list(stock_data.columns)}).")
    close = stock_data["Close"]
    if isinstance(close, pd.DataFrame):
        # Duplicate 'Close' columns mean several tickers were downloaded at once.
        raise ValueError("Please enter exactly one ticker symbol.")

    df = close.rename_axis("ds").reset_index(name="y")
    df["ds"] = pd.to_datetime(df["ds"])
    # Missing closes would propagate NaN through the model's input context.
    df = df.dropna(subset=["y"])
    if df.empty:
        raise ValueError(f"No valid closing prices for '{ticker}' in the requested date range.")
    df["unique_id"] = ticker
    return df


def _build_figure(history, forecast, ticker):
    """Plot actual closing prices and the forecast on one dark-themed Plotly chart."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=history["ds"], y=history["y"],
                             mode="lines", name="Actual Prices",
                             line=dict(color="cyan", width=2)))
    fig.add_trace(go.Scatter(x=forecast["ds"], y=forecast["forecast"],
                             mode="lines", name="Forecasted Prices",
                             line=dict(color="magenta", width=2, dash="dash")))
    fig.update_layout(
        title=f"{ticker} Stock Price Forecast (Interactive)",
        xaxis_title="Date",
        yaxis_title="Price",
        template="plotly_dark",  # Dark Theme
        hovermode="x unified",  # Show all values on hover
        legend=dict(bgcolor="black", bordercolor="white"),
    )
    return fig


def stock_forecast(ticker, start_date, end_date):
    """Download closing prices for ``ticker``, forecast them with TimesFM, and plot both.

    Args:
        ticker: Stock symbol as typed by the user; surrounding whitespace is
            ignored and it is upper-cased.
        start_date: First date to download, passed to ``yf.download`` as ``start``.
        end_date: Date passed to ``yf.download`` as ``end``.

    Returns:
        A Plotly ``go.Figure`` with the historical closes and the next
        ``_HORIZON_LEN`` trading-day forecast.

    Raises:
        gr.Error: on any failure (bad ticker, empty data, model/download error),
            so Gradio shows the message to the user instead of handing a string
            to the ``gr.Plot`` output.
    """
    ticker = (ticker or "").strip().upper()
    try:
        if not ticker:
            raise ValueError("Please enter a stock ticker.")

        stock_data = yf.download(ticker, start=start_date, end=end_date)
        df = _prepare_series(stock_data, ticker)

        forecast_df = _load_model().forecast_on_df(
            inputs=df,
            # Quotes exist only on business days; "B" makes the generated future
            # timestamps skip weekends, matching the spacing of the input series.
            freq="B",
            value_name="y",
            # A single series: a multiprocessing pool would be pure overhead.
            num_jobs=1,
        ).rename(columns={"timesfm": "forecast"})

        return _build_figure(df, forecast_df, ticker)
    except Exception as e:  # UI boundary: surface every failure to the user
        raise gr.Error(f"Error: {e}") from e
# Gradio UI: ticker/date text inputs, an "Enter" button that runs
# stock_forecast, and a Plotly output for the resulting chart.
with gr.Blocks() as demo:
    gr.Markdown("# Stock Price Forecast App")
    gr.Markdown("Enter a stock ticker, start date, and end date to visualize historical and forecasted stock prices.")
    with gr.Row():
        ticker_input = gr.Textbox(label="Enter Stock Ticker", value="NVDA")
        start_date_input = gr.Textbox(label="Enter Start Date (YYYY-MM-DD)", value="2022-01-01")
        end_date_input = gr.Textbox(label="Enter End Date (YYYY-MM-DD)", value="2026-01-01")
    submit_button = gr.Button("Enter")
    plot_output = gr.Plot()
    # Event listeners must be registered inside the Blocks context.
    submit_button.click(
        stock_forecast,
        inputs=[ticker_input, start_date_input, end_date_input],
        outputs=plot_output,
    )

# Only start the server when run as a script. Importing this module (Gradio
# reload mode, tests, multiprocessing "spawn" children) must not launch it.
if __name__ == "__main__":
    demo.launch()