import functools

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import timesfm
import yfinance as yf

# Number of future steps (trading days) the model forecasts.
HORIZON_LEN = 30


@functools.lru_cache(maxsize=1)
def _get_model():
    """Build the TimesFM 2.0 (500M, PyTorch) model once and reuse it for every request.

    Loading the checkpoint is expensive, so it is cached instead of being rebuilt on each
    button click. A failed load raises and is not cached, so the next call retries.
    """
    return timesfm.TimesFm(
        hparams=timesfm.TimesFmHparams(
            backend="pytorch",
            per_core_batch_size=32,
            horizon_len=HORIZON_LEN,
            input_patch_len=32,
            output_patch_len=128,
            num_layers=50,
            model_dims=1280,
            use_positional_embedding=False,
        ),
        checkpoint=timesfm.TimesFmCheckpoint(
            huggingface_repo_id="google/timesfm-2.0-500m-pytorch"
        ),
    )


def _load_history(ticker, start_date, end_date):
    """Download daily closing prices for *ticker* in the long format TimesFM expects.

    Returns a DataFrame with columns ``unique_id`` (the ticker), ``ds`` (datetime) and
    ``y`` (closing price), with NaN closes dropped.

    Raises:
        gr.Error: if no usable closing prices are returned for the request.
    """
    stock_data = yf.download(ticker, start=start_date, end=end_date)
    if stock_data is None or stock_data.empty:
        raise gr.Error(
            f"No price data found for '{ticker}' between {start_date} and {end_date}."
        )

    # yfinance may return (Price, Ticker) MultiIndex columns even for a single ticker;
    # drop the Ticker level so price fields can be selected by their real names.
    # (Columns are selected by name rather than relabelled positionally, because the
    # returned column set/order is not guaranteed across yfinance versions/options.)
    if isinstance(stock_data.columns, pd.MultiIndex):
        stock_data.columns = stock_data.columns.droplevel(level=1)

    if "Close" not in stock_data.columns:
        raise gr.Error(f"Downloaded data for '{ticker}' has no 'Close' column.")

    close = stock_data["Close"]
    # Several tickers (e.g. "AAPL MSFT") yield duplicate 'Close' columns after droplevel.
    if isinstance(close, pd.DataFrame):
        raise gr.Error("Please enter a single stock ticker.")

    close = close.dropna()
    if close.empty:
        raise gr.Error(f"No closing prices available for '{ticker}' in that range.")

    return pd.DataFrame(
        {
            "unique_id": ticker,
            "ds": pd.to_datetime(close.index),
            "y": close.to_numpy(),
        }
    )


def _build_figure(ticker, history, forecast):
    """Plot historical closes and the point forecast on a dark-themed Plotly figure."""
    fig = go.Figure()

    # Actual prices line
    fig.add_trace(
        go.Scatter(
            x=history["ds"],
            y=history["y"],
            mode="lines",
            name="Actual Prices",
            line=dict(color="cyan", width=2),
        )
    )

    # Forecasted prices line
    fig.add_trace(
        go.Scatter(
            x=forecast["ds"],
            y=forecast["forecast"],
            mode="lines",
            name="Forecasted Prices",
            line=dict(color="magenta", width=2, dash="dash"),
        )
    )

    fig.update_layout(
        title=f"{ticker} Stock Price Forecast (Interactive)",
        xaxis_title="Date",
        yaxis_title="Price",
        template="plotly_dark",  # Dark theme
        hovermode="x unified",  # Show all values on hover
        legend=dict(bgcolor="black", bordercolor="white"),
    )
    return fig


def stock_forecast(ticker, start_date, end_date):
    """Fetch daily closes for *ticker*, forecast the next HORIZON_LEN trading days, and plot both.

    Args:
        ticker: Stock symbol understood by yfinance (surrounding whitespace is ignored).
        start_date: Start of the history window, as accepted by ``yf.download``
            (the UI asks for YYYY-MM-DD).
        end_date: End of the history window, same format as *start_date*.

    Returns:
        A Plotly figure with the historical series and the dashed forecast line.

    Raises:
        gr.Error: on invalid input or any download/forecast failure. Gradio shows the
            message to the user; returning a string to a ``gr.Plot`` output cannot be
            rendered.
    """
    ticker = (ticker or "").strip()
    if not ticker:
        raise gr.Error("Please enter a stock ticker.")

    try:
        history = _load_history(ticker, start_date, end_date)

        forecast_df = _get_model().forecast_on_df(
            inputs=history,
            # Business-day frequency: the input series only has trading days, so forecast
            # timestamps should skip weekends too. TimesFM maps "B" to the same frequency
            # category as "D", so the model input is unchanged; only output dates differ.
            freq="B",
            value_name="y",
            num_jobs=-1,
        )
        # TimesFM names the point-forecast column after the model ("timesfm").
        forecast_df = forecast_df.rename(columns={"timesfm": "forecast"})

        return _build_figure(ticker, history, forecast_df)
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Error: {str(e)}") from e


# Gradio UI: ticker/date text inputs, an "Enter" button, and an interactive plot output.
with gr.Blocks() as demo:
    gr.Markdown("# Stock Price Forecast App")
    gr.Markdown(
        "Enter a stock ticker, start date, and end date to visualize historical and forecasted stock prices."
    )

    with gr.Row():
        ticker_input = gr.Textbox(label="Enter Stock Ticker", value="NVDA")
        start_date_input = gr.Textbox(label="Enter Start Date (YYYY-MM-DD)", value="2022-01-01")
        end_date_input = gr.Textbox(label="Enter End Date (YYYY-MM-DD)", value="2026-01-01")

    submit_button = gr.Button("Enter")
    plot_output = gr.Plot()

    # Link the button to the forecasting function
    submit_button.click(
        stock_forecast,
        inputs=[ticker_input, start_date_input, end_date_input],
        outputs=plot_output,
    )

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    demo.launch()