# Scraped from a Hugging Face Space page (status: Running).
import functools

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import timesfm
import yfinance as yf
# Number of future periods (business days, see freq="B" below) to forecast.
_HORIZON_LEN = 30


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the TimesFM model once and reuse it for every request.

    Loading the google/timesfm-2.0-500m-pytorch checkpoint is expensive, so
    it is cached for the lifetime of the process instead of being rebuilt on
    every button click.
    """
    return timesfm.TimesFm(
        hparams=timesfm.TimesFmHparams(
            backend="pytorch",
            per_core_batch_size=32,
            horizon_len=_HORIZON_LEN,
            input_patch_len=32,
            output_patch_len=128,
            num_layers=50,
            model_dims=1280,
            use_positional_embedding=False,
        ),
        checkpoint=timesfm.TimesFmCheckpoint(
            huggingface_repo_id="google/timesfm-2.0-500m-pytorch"
        ),
    )


def _fetch_history(ticker, start_date, end_date):
    """Download closing prices and shape them for TimesFM.

    Returns a DataFrame with columns ``ds`` (datetime), ``y`` (close price)
    and ``unique_id`` (the ticker).

    Raises:
        ValueError: if yfinance returns no usable price rows.
    """
    stock_data = yf.download(ticker, start=start_date, end=end_date)
    if stock_data is None or stock_data.empty:
        raise ValueError(
            f"No price data found for '{ticker}' between {start_date} and {end_date}."
        )
    # yfinance may return (field, ticker) MultiIndex columns even for one
    # ticker; keep only the field level ("Close", "High", ...). Columns are
    # then picked by NAME below rather than relabelled positionally, so an
    # extra column such as "Adj Close" or a different ordering cannot break
    # or silently corrupt the data.
    if isinstance(stock_data.columns, pd.MultiIndex):
        stock_data.columns = stock_data.columns.get_level_values(0)

    history = (
        stock_data.reset_index()[["Date", "Close"]]
        .rename(columns={"Date": "ds", "Close": "y"})
        .dropna(subset=["y"])
        .assign(ds=lambda d: pd.to_datetime(d["ds"]), unique_id=ticker)
    )
    if history.empty:
        raise ValueError(f"No closing prices available for '{ticker}'.")
    return history


def _build_figure(ticker, history, forecast_df):
    """Return a dark-themed Plotly chart of actual closes and the forecast."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=history["ds"], y=history["y"],
                             mode="lines", name="Actual Prices",
                             line=dict(color="cyan", width=2)))
    fig.add_trace(go.Scatter(x=forecast_df["ds"], y=forecast_df["forecast"],
                             mode="lines", name="Forecasted Prices",
                             line=dict(color="magenta", width=2, dash="dash")))
    fig.update_layout(
        title=f"{ticker} Stock Price Forecast (Interactive)",
        xaxis_title="Date",
        yaxis_title="Price",
        template="plotly_dark",
        hovermode="x unified",  # show all series' values on hover
        legend=dict(bgcolor="black", bordercolor="white"),
    )
    return fig


# Function to fetch stock data, generate forecast, and create an interactive plot
def stock_forecast(ticker, start_date, end_date):
    """Fetch prices for *ticker*, forecast ahead with TimesFM, and plot both.

    Args:
        ticker: Stock symbol passed to yfinance, e.g. ``"NVDA"``; surrounding
            whitespace is ignored.
        start_date: Start date string (``YYYY-MM-DD``) passed to yfinance.
        end_date: End date string (``YYYY-MM-DD``) passed to yfinance.

    Returns:
        A Plotly ``Figure`` with the historical and forecasted prices.

    Raises:
        gr.Error: on any failure. Gradio displays the message in the UI,
            whereas returning an error string to a ``gr.Plot`` output does not.
    """
    ticker = (ticker or "").strip().upper()
    if not ticker:
        raise gr.Error("Please enter a stock ticker.")
    try:
        history = _fetch_history(ticker, start_date, end_date)
        forecast_df = _load_model().forecast_on_df(
            inputs=history,
            # "B" (business days) keeps future timestamps off weekends and
            # maps to the same TimesFM frequency class as "D".
            freq="B",
            value_name="y",
            # Only one series: avoid spawning a process pool per request.
            num_jobs=1,
        )
        # TimesFM names its point-forecast column after the model ("timesfm").
        forecast_df = forecast_df.rename(columns={"timesfm": "forecast"})
        return _build_figure(ticker, history, forecast_df)
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e
# Build the Gradio UI: ticker/date inputs, an "Enter" button, and a Plotly output.
with gr.Blocks() as demo:
    gr.Markdown("# Stock Price Forecast App")
    gr.Markdown("Enter a stock ticker, start date, and end date to visualize historical and forecasted stock prices.")
    with gr.Row():
        ticker_input = gr.Textbox(label="Enter Stock Ticker", value="NVDA")
        start_date_input = gr.Textbox(label="Enter Start Date (YYYY-MM-DD)", value="2022-01-01")
        end_date_input = gr.Textbox(label="Enter End Date (YYYY-MM-DD)", value="2026-01-01")
    submit_button = gr.Button("Enter")
    plot_output = gr.Plot()

    forecast_inputs = [ticker_input, start_date_input, end_date_input]
    # Run the forecast on a button click, and also when the user presses the
    # Enter key inside any of the input boxes (the button is labelled "Enter").
    submit_button.click(
        stock_forecast,
        inputs=forecast_inputs,
        outputs=plot_output,
    )
    for textbox in forecast_inputs:
        textbox.submit(
            stock_forecast,
            inputs=forecast_inputs,
            outputs=plot_output,
        )

# Launch the Gradio app
demo.launch()