File size: 9,005 Bytes
e3003c7
 
 
 
 
 
 
 
 
c3585da
 
e3003c7
 
 
 
 
 
4fc19bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3003c7
4fc19bf
e3003c7
 
 
 
 
 
 
4fc19bf
 
 
e3003c7
4fc19bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3585da
e3003c7
4fc19bf
e3003c7
 
 
4fc19bf
e3003c7
 
 
4fc19bf
 
e3003c7
 
 
 
 
 
 
 
d460df7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import gradio as gr
import torch
from chronos import ChronosPipeline
import yfinance as yf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from sklearn.metrics import mean_absolute_error, mean_squared_error
import tempfile

def get_popular_tickers():
    """Return the ticker symbols offered in the stock dropdown.

    Returns:
        A new list of 16 ticker symbol strings on every call, so callers
        may mutate the result freely.
    """
    symbols = (
        "AAPL",
        "MSFT",
        "GOOGL",
        "AMZN",
        "META",
        "TSLA",
        "NVDA",
        "JPM",
        "JNJ",
        "V",
        "PG",
        "WMT",
        "BAC",
        "DIS",
        "NFLX",
        "INTC",
    )
    return list(symbols)

# Lazily-created Chronos pipeline shared by all requests. Loading the model is
# by far the slowest step, so it is done once instead of on every click.
_CHRONOS_PIPELINE = None


def _get_chronos_pipeline():
    """Return the shared Chronos pipeline, loading it on first use."""
    global _CHRONOS_PIPELINE
    if _CHRONOS_PIPELINE is None:
        _CHRONOS_PIPELINE = ChronosPipeline.from_pretrained(
            "amazon/chronos-t5-mini",
            device_map="cpu",
            torch_dtype=torch.float32
        )
    return _CHRONOS_PIPELINE


def predict_stock(ticker, train_data_points, prediction_days):
    """Forecast closing prices for ``ticker`` and plot them against history.

    The full price history is downloaded with yfinance. The first
    ``train_data_points`` closing prices are used as model context and the
    next ``prediction_days`` business days are forecast. When the history
    already covers the whole forecast window, the real prices are overlaid
    and MAE / RMSE / MAPE are shown in the plot title.

    Args:
        ticker: Ticker symbol understood by yfinance.
        train_data_points: Number of leading history points used as context
            (coerced to ``int`` and capped at the available history length).
        prediction_days: Number of business days to forecast (coerced to
            ``int``).

    Returns:
        A tuple ``(plt, csv_path)``: the ``matplotlib.pyplot`` module holding
        the current figure, and the path of a temporary CSV file with the
        forecast (plus a ``Real_Price`` column when validation data exists).

    Raises:
        gr.Error: On any failure (no data, too little data, model or
            plotting errors); the original exception is chained as the cause.
    """
    try:
        # Slider values may arrive as floats; the model needs integers.
        train_data_points = int(train_data_points)
        prediction_days = int(prediction_days)

        pipeline = _get_chronos_pipeline()

        # Fetch the historical data.
        stock = yf.Ticker(ticker)
        hist = stock.history(period="max")
        if hist.empty:
            raise ValueError(f"No hay datos disponibles para {ticker}")

        close_col = f'{ticker}_Close'
        df = hist[['Close']].reset_index().rename(columns={'Close': close_col})

        total_points = len(df)
        if total_points < 50:
            raise ValueError(f"Datos insuficientes para {ticker}")

        # Never use more context points than the history provides.
        train_data_points = min(train_data_points, total_points)

        # Model context: the first ``train_data_points`` closing prices.
        context = torch.tensor(df[close_col][:train_data_points].values, dtype=torch.float32)

        # Run the forecast and summarise the samples as 1% / 50% / 99% quantiles.
        forecast = pipeline.predict(context, prediction_days, limit_prediction_length=False)
        low, median, high = np.quantile(forecast[0].numpy(), [0.01, 0.5, 0.99], axis=0)

        # Close figures left over from earlier requests so they do not
        # accumulate in memory for the lifetime of the server process.
        plt.close('all')
        plt.figure(figsize=(20, 10))

        # Show up to 10 context days plus whatever real data covers the forecast.
        context_days = min(10, train_data_points)
        start_index = max(0, train_data_points - context_days)
        end_index = min(train_data_points + prediction_days, total_points)

        # Plot the historical data.
        historical_dates = df['Date'][start_index:end_index]
        historical_data = df[close_col][start_index:end_index].values
        plt.plot(historical_dates,
                 historical_data,
                 color='blue',
                 linewidth=2,
                 label='Datos Reales')

        # The forecast starts at the first date after the context window,
        # or the day after the last known date when the context uses all data.
        if train_data_points < total_points:
            prediction_start_date = df['Date'].iloc[train_data_points]
        else:
            last_date = df['Date'].iloc[-1]
            prediction_start_date = last_date + pd.Timedelta(days=1)

        prediction_dates = pd.date_range(start=prediction_start_date, periods=prediction_days, freq='B')

        # Plot the forecast median.
        plt.plot(prediction_dates,
                 median,
                 color='black',
                 linewidth=2,
                 linestyle='-',
                 label='Predicción')

        # Confidence band between the low and high quantiles.
        plt.fill_between(prediction_dates, low, high,
                         color='gray', alpha=0.2,
                         label='Intervalo de Confianza')

        prediction_df = pd.DataFrame({
            'Date': prediction_dates,
            'Predicted_Price': median,
            'Lower_Bound': low,
            'Upper_Bound': high
        })
        output_df = prediction_df

        # Validation is only done when real data covers the whole forecast window.
        overlap_end_index = train_data_points + prediction_days
        validation = None
        if overlap_end_index <= total_points:
            real_future = df.iloc[train_data_points:overlap_end_index][['Date', close_col]]
            real_future = real_future.rename(columns={close_col: 'Real_Price'})
            # Join on the date so every real price is paired with the forecast
            # for that same day. The business-day calendar contains market
            # holidays that have no real price, so positional pairing would
            # shift the series against each other.
            validation = prediction_df.merge(real_future, on='Date', how='inner')

        if validation is not None and not validation.empty:
            real_prices = validation['Real_Price'].values
            predicted_prices = validation['Predicted_Price'].values

            plt.plot(validation['Date'],
                     real_prices,
                     color='red',
                     linewidth=2,
                     linestyle='--',
                     label='Datos Reales de Validación')

            mae = mean_absolute_error(real_prices, predicted_prices)
            rmse = np.sqrt(mean_squared_error(real_prices, predicted_prices))
            mape = np.mean(np.abs((real_prices - predicted_prices) / real_prices)) * 100
            plt.title(f"Predicción del Precio de {ticker}\nMAE: {mae:.2f} | RMSE: {rmse:.2f} | MAPE: {mape:.2f}%",
                      fontsize=14, pad=20)
            output_df = validation
        else:
            plt.title(f"Predicción Futura del Precio de {ticker}",
                      fontsize=14, pad=20)

        plt.legend(loc="upper left", fontsize=12)
        plt.xlabel("Fecha", fontsize=12)
        plt.ylabel("Precio", fontsize=12)

        plt.grid(True, which='both', axis='x', linestyle='--', linewidth=0.5)

        # One labelled tick per calendar day.
        ax = plt.gca()
        ax.xaxis.set_major_locator(mdates.DayLocator())
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))

        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        plt.tight_layout()

        # Reserve a temporary file name and close the handle before pandas
        # writes to it; writing to a still-open temp file fails on some platforms.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.csv') as temp_csv:
            csv_path = temp_csv.name
        output_df.to_csv(csv_path, index=False)

        return plt, csv_path

    except Exception as e:
        print(f"Error: {str(e)}")
        raise gr.Error(f"Error al procesar {ticker}: {str(e)}") from e

def update_train_data_points(ticker):
    """Build a slider update sized to the history available for ``ticker``.

    Args:
        ticker: Ticker symbol selected in the dropdown; may be empty.

    Returns:
        A ``gr.Slider.update`` payload. With a usable ticker the maximum is
        the number of history rows and the value is at most 1000. With no
        ticker, or on any lookup failure, default limits (value 1000,
        maximum 5000) are returned instead of raising.
    """
    if ticker:
        try:
            history = yf.Ticker(ticker).history(period="max")
            if history.empty:
                raise ValueError(f"No hay datos disponibles para {ticker}")

            available = len(history)
            if available < 50:
                raise ValueError(f"Datos insuficientes para {ticker}")

            slider_settings = {
                "maximum": available,
                "value": min(1000, available),
                "minimum": 50,
                "step": 1,
                "interactive": True,
            }
            return gr.Slider.update(**slider_settings)
        except Exception as e:
            # Best effort: log the problem and fall back to the default range.
            print(f"Error al actualizar datos para {ticker}: {e}")
            return gr.Slider.update(value=1000, maximum=5000, minimum=50, step=1)

    # No ticker selected: restore the default value and upper bound only.
    return gr.Slider.update(value=1000, maximum=5000)

# Gradio interface: inputs in the first row, outputs (plot and CSV) below.
with gr.Blocks() as demo:
    gr.Markdown("# Aplicación de Predicción de Precios de Acciones")

    with gr.Row():
        with gr.Column(scale=1):
            ticker = gr.Dropdown(
                choices=get_popular_tickers(),
                value="AAPL",
                label="Selecciona el Símbolo de la Acción",
                interactive=True
            )
            with gr.Column():
                train_data_points = gr.Slider(
                    minimum=50,
                    maximum=5000,
                    value=1000,
                    step=1,
                    label="Número de Datos para Entrenamiento",
                    interactive=True
                )
                prediction_days = gr.Slider(
                    minimum=1,
                    maximum=60,
                    value=5,
                    step=1,
                    label="Número de Días a Predecir",
                    interactive=True
                )
            predict_btn = gr.Button("Predecir", interactive=True)

    with gr.Column():
        # Hidden status box; no event in this file writes to it.
        error_output = gr.Textbox(label="Estado", visible=False)
        plot_output = gr.Plot(label="Gráfico de Predicción")
        download_btn = gr.File(label="Descargar Predicciones")

    # Events
    # Changing the ticker resizes the training slider to the available history.
    ticker.change(
        fn=update_train_data_points,
        inputs=[ticker],
        outputs=[train_data_points],
        api_name="update_data"
    )

    # The button runs the forecast and fills both the plot and the download.
    predict_btn.click(
        fn=predict_stock,
        inputs=[ticker, train_data_points, prediction_days],
        outputs=[plot_output, download_btn]
    )

# Only start the server when executed as a script, so importing this module
# (for example to reuse ``demo`` or the helper functions) has no side effects.
if __name__ == "__main__":
    demo.launch()