Spaces:
Running
Running
File size: 3,951 Bytes
804260b d090fb9 804260b 77f9d45 d090fb9 804260b 77f9d45 cb77ab1 77f9d45 804260b 77f9d45 804260b 77f9d45 804260b 77f9d45 804260b 77f9d45 804260b f341752 804260b f341752 804260b 77f9d45 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import yfinance as yf
import pandas as pd
import plotly.graph_objects as go
import gradio as gr
import timesfm
# Lazily-created TimesFM model shared by every forecast request, so the
# checkpoint is not rebuilt on each button click.
_TFM_MODEL = None


def _get_timesfm_model():
    """Return the shared TimesFM model, constructing it on first use."""
    global _TFM_MODEL
    if _TFM_MODEL is None:
        _TFM_MODEL = timesfm.TimesFm(
            hparams=timesfm.TimesFmHparams(
                backend="pytorch",
                per_core_batch_size=32,
                horizon_len=30,  # Predicting the next 30 days
                input_patch_len=32,
                output_patch_len=128,
                num_layers=50,
                model_dims=1280,
                use_positional_embedding=False,
            ),
            checkpoint=timesfm.TimesFmCheckpoint(
                huggingface_repo_id="google/timesfm-2.0-500m-pytorch"
            ),
        )
    return _TFM_MODEL


def _prepare_series(stock_data, ticker):
    """Turn a yfinance price table into the ds / y / unique_id frame.

    Raises:
        ValueError: if the download is empty, has no single 'Close' column,
            or contains no usable closing prices.
    """
    if stock_data is None or stock_data.empty:
        raise ValueError(
            f"No price data returned for '{ticker}'. "
            "Check the ticker symbol and the date range."
        )

    # Work on a copy so the downloaded frame is not modified in place.
    prices = stock_data.copy()

    # If the DataFrame has a MultiIndex for columns, drop the 'Ticker' level
    if isinstance(prices.columns, pd.MultiIndex):
        prices.columns = prices.columns.droplevel(level=1)

    # Pick 'Close' by name rather than relabelling columns by position: the
    # number and order of columns yfinance returns is not fixed (for example
    # an extra 'Adj Close' column), so positional names can fail or mislabel.
    if "Close" not in prices.columns:
        raise ValueError(f"Downloaded data for '{ticker}' has no 'Close' column.")
    close = prices["Close"]
    if isinstance(close, pd.DataFrame):
        # Several tickers in one request leave duplicate 'Close' columns.
        raise ValueError("Please enter a single ticker symbol.")

    df = pd.DataFrame(
        {
            "ds": pd.to_datetime(prices.index),
            "y": close.to_numpy(),
        }
    )
    # Missing closing prices cannot be used as model context.
    df = df.dropna(subset=["y"]).reset_index(drop=True)
    if df.empty:
        raise ValueError(f"No closing prices available for '{ticker}'.")

    # Add a unique identifier for the time series
    df["unique_id"] = ticker
    return df


def stock_forecast(ticker, start_date, end_date):
    """Fetch stock data, generate a forecast, and build an interactive plot.

    Args:
        ticker: Ticker symbol passed to ``yf.download``.
        start_date: Start of the historical window (``YYYY-MM-DD``).
        end_date: End of the historical window (``YYYY-MM-DD``).

    Returns:
        A Plotly figure with the historical closing prices and the TimesFM
        forecast (horizon of 30 steps, as configured on the model).

    Raises:
        gr.Error: on any failure (bad ticker, empty download, model error),
            so Gradio shows the message instead of handing text to the plot
            component.
    """
    try:
        symbol = (ticker or "").strip()
        if not symbol:
            raise ValueError("Please enter a stock ticker.")

        # Fetch historical data
        stock_data = yf.download(symbol, start=start_date, end=end_date)
        df = _prepare_series(stock_data, symbol)

        # Forecast using the prepared DataFrame
        forecast_df = _get_timesfm_model().forecast_on_df(
            inputs=df,
            freq="D",  # Daily frequency
            value_name="y",
            num_jobs=-1,
        )
        # Expose the point forecast under a friendlier column name
        forecast_df = forecast_df.rename(columns={"timesfm": "forecast"})

        # Create an interactive plot with Plotly
        fig = go.Figure()
        # Add Actual Prices Line
        fig.add_trace(
            go.Scatter(
                x=df["ds"],
                y=df["y"],
                mode="lines",
                name="Actual Prices",
                line=dict(color="cyan", width=2),
            )
        )
        # Add Forecasted Prices Line
        fig.add_trace(
            go.Scatter(
                x=forecast_df["ds"],
                y=forecast_df["forecast"],
                mode="lines",
                name="Forecasted Prices",
                line=dict(color="magenta", width=2, dash="dash"),
            )
        )
        # Layout Customization
        fig.update_layout(
            title=f"{symbol} Stock Price Forecast (Interactive)",
            xaxis_title="Date",
            yaxis_title="Price",
            template="plotly_dark",  # Dark Theme
            hovermode="x unified",  # Show all values on hover
            legend=dict(bgcolor="black", bordercolor="white"),
        )
    except Exception as e:
        # A gr.Plot output cannot render a string; raise a Gradio error so
        # the message is displayed to the user.
        raise gr.Error(f"Error: {str(e)}") from e
    return fig  # Return the Plotly figure for Gradio
# Gradio UI: a row of three text inputs, an "Enter" button, and a plot area.
# NOTE(review): the nesting below is reconstructed from syntax; confirm that
# only the three text boxes belong inside the row.
with gr.Blocks() as demo:
    gr.Markdown("# Stock Price Forecast App")
    gr.Markdown("Enter a stock ticker, start date, and end date to visualize historical and forecasted stock prices.")

    with gr.Row():
        ticker_input = gr.Textbox(value="NVDA", label="Enter Stock Ticker")
        start_date_input = gr.Textbox(
            value="2022-01-01",
            label="Enter Start Date (YYYY-MM-DD)",
        )
        end_date_input = gr.Textbox(
            value="2026-01-01",
            label="Enter End Date (YYYY-MM-DD)",
        )

    submit_button = gr.Button("Enter")
    plot_output = gr.Plot()

    # Clicking the button runs the forecast on the three field values and
    # sends the result to the plot component.
    form_inputs = [ticker_input, start_date_input, end_date_input]
    submit_button.click(
        fn=stock_forecast,
        inputs=form_inputs,
        outputs=plot_output,
    )

# Start the Gradio server
demo.launch()
|