File size: 3,951 Bytes
804260b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d090fb9
804260b
 
77f9d45
 
d090fb9
804260b
77f9d45
cb77ab1
77f9d45
 
 
 
804260b
77f9d45
 
 
 
804260b
77f9d45
 
 
804260b
 
77f9d45
 
 
804260b
77f9d45
 
 
 
 
804260b
 
 
 
 
 
 
f341752
804260b
f341752
804260b
 
 
 
 
 
 
 
 
 
 
 
77f9d45
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import functools

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import timesfm
import yfinance as yf

# Number of future steps (trading days) the model is asked to predict.
_HORIZON_LEN = 30


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the TimesFM forecaster once and reuse it for every request.

    Constructing the model loads the 500M-parameter checkpoint from the
    Hugging Face Hub, which costs far more than one forecast, so the instance
    is cached instead of being rebuilt on every button click. If loading
    raises, nothing is cached and the next call retries.
    """
    return timesfm.TimesFm(
        hparams=timesfm.TimesFmHparams(
            backend="pytorch",
            per_core_batch_size=32,
            horizon_len=_HORIZON_LEN,
            # Architecture values (patch lengths, layers, dims) must match
            # the checkpoint being loaded below.
            input_patch_len=32,
            output_patch_len=128,
            num_layers=50,
            model_dims=1280,
            use_positional_embedding=False,
        ),
        checkpoint=timesfm.TimesFmCheckpoint(
            huggingface_repo_id="google/timesfm-2.0-500m-pytorch"
        ),
    )


def _fetch_close_prices(ticker, start_date, end_date):
    """Download daily history and return closing prices in TimesFM's long format.

    Returns a DataFrame with columns ``ds`` (datetime), ``y`` (close price) and
    ``unique_id`` (the ticker). Raises ``gr.Error`` when no data is available.
    """
    no_data = f"No price data found for '{ticker}' between {start_date} and {end_date}."

    stock_data = yf.download(ticker, start=start_date, end=end_date)
    if stock_data is None or stock_data.empty:
        raise gr.Error(no_data)

    # Newer yfinance releases return (Price, Ticker) MultiIndex columns even
    # for a single ticker; keep only the price-field level.
    if isinstance(stock_data.columns, pd.MultiIndex):
        stock_data.columns = stock_data.columns.droplevel(level=1)

    # Select "Close" by name instead of relabelling every column by position:
    # the set and order of columns differ between yfinance versions (older
    # ones also return "Adj Close"), so positional renaming is fragile.
    close = stock_data["Close"]
    df = pd.DataFrame(
        {
            "ds": pd.to_datetime(close.index),
            "y": close.to_numpy(),
        }
    ).dropna(subset=["y"])
    if df.empty:
        raise gr.Error(no_data)

    # A single series still needs an identifier for forecast_on_df.
    df["unique_id"] = ticker
    return df


def _build_figure(ticker, history, forecast):
    """Plot actual closes and the forecast on one dark-themed Plotly figure."""
    fig = go.Figure()

    fig.add_trace(go.Scatter(x=history["ds"], y=history["y"],
                             mode="lines", name="Actual Prices",
                             line=dict(color="cyan", width=2)))

    fig.add_trace(go.Scatter(x=forecast["ds"], y=forecast["forecast"],
                             mode="lines", name="Forecasted Prices",
                             line=dict(color="magenta", width=2, dash="dash")))

    fig.update_layout(
        title=f"{ticker} Stock Price Forecast (Interactive)",
        xaxis_title="Date",
        yaxis_title="Price",
        template="plotly_dark",
        hovermode="x unified",  # Show all traces' values for the hovered date
        legend=dict(bgcolor="black", bordercolor="white"),
    )
    return fig


def stock_forecast(ticker, start_date, end_date):
    """Fetch a ticker's closing prices, forecast them with TimesFM, and plot both.

    Args:
        ticker: Symbol passed to ``yfinance.download`` (e.g. ``"NVDA"``).
        start_date: Start of the download window, as a ``YYYY-MM-DD`` string.
        end_date: End of the download window, as a ``YYYY-MM-DD`` string.

    Returns:
        A Plotly ``Figure`` with the historical closes and the forecast for
        the next ``_HORIZON_LEN`` trading days.

    Raises:
        gr.Error: If no data is found or any step fails. Gradio displays the
            message in the UI. Returning a plain string to a ``gr.Plot``
            output would instead fail inside Gradio.
    """
    try:
        df = _fetch_close_prices(ticker, start_date, end_date)

        forecast_df = _load_model().forecast_on_df(
            inputs=df,
            # The series is indexed by trading days, so each forecast step is
            # the next business day. "B" generates weekday future dates,
            # whereas "D" would put the steps on calendar days, weekends
            # included. Exchange holidays are still not skipped.
            freq="B",
            value_name="y",
            # One series only: a process pool would add startup cost without
            # any parallelism.
            num_jobs=1,
        )

        # TimesFM names the point-forecast column after the model.
        forecast_df = forecast_df.rename(columns={"timesfm": "forecast"})

        return _build_figure(ticker, df, forecast_df)

    except gr.Error:
        raise
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing
        # the handler.
        raise gr.Error(f"Error: {e}") from e

# Gradio UI: ticker and date inputs, an "Enter" button, and a Plotly output.
# `demo` stays at module level so tools that import this file can find it.
with gr.Blocks() as demo:
    gr.Markdown("# Stock Price Forecast App")
    gr.Markdown("Enter a stock ticker, start date, and end date to visualize historical and forecasted stock prices.")

    with gr.Row():
        ticker_input = gr.Textbox(label="Enter Stock Ticker", value="NVDA")
        start_date_input = gr.Textbox(label="Enter Start Date (YYYY-MM-DD)", value="2022-01-01")
        end_date_input = gr.Textbox(label="Enter End Date (YYYY-MM-DD)", value="2026-01-01")

    submit_button = gr.Button("Enter")
    plot_output = gr.Plot()

    # Run the forecast when the button is clicked and render the figure.
    submit_button.click(
        stock_forecast,
        inputs=[ticker_input, start_date_input, end_date_input],
        outputs=plot_output
    )

# Start the server only when run as a script. Importing this module must not
# launch a server. For example, multiprocessing workers using the "spawn"
# start method re-import the main module.
if __name__ == "__main__":
    demo.launch()