Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from chronos import ChronosPipeline | |
import yfinance as yf | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.dates as mdates | |
from sklearn.metrics import mean_absolute_error, mean_squared_error | |
import tempfile | |
def get_popular_tickers():
    """Return a new list of the ticker symbols offered in the dropdown.

    The order of the list is the display order. A fresh list is built on every
    call, so callers may mutate it freely.
    """
    symbols = (
        "AAPL", "MSFT", "GOOGL", "AMZN",
        "META", "TSLA", "NVDA", "JPM",
        "JNJ", "V", "PG", "WMT",
        "BAC", "DIS", "NFLX", "INTC",
    )
    return list(symbols)
# Loaded Chronos pipelines keyed by model name. The weights are read from
# disk once per process instead of on every button click.
_PIPELINE_CACHE = {}


def _get_pipeline(model_name="amazon/chronos-t5-mini"):
    """Return a CPU/float32 Chronos pipeline for *model_name*, loading it on first use."""
    pipeline = _PIPELINE_CACHE.get(model_name)
    if pipeline is None:
        pipeline = ChronosPipeline.from_pretrained(
            model_name,
            device_map="cpu",
            torch_dtype=torch.float32,
        )
        _PIPELINE_CACHE[model_name] = pipeline
    return pipeline


def _load_close_history(ticker):
    """Download the full history of *ticker* as a ``Date`` / ``<ticker>_Close`` frame.

    The returned DataFrame has a 0-based RangeIndex (from ``reset_index``).

    Raises:
        ValueError: if no data is returned or fewer than 50 rows are available.
    """
    hist = yf.Ticker(ticker).history(period="max")
    if hist.empty:
        raise ValueError(f"No hay datos disponibles para {ticker}")
    df = hist[['Close']].reset_index().rename(columns={'Close': f'{ticker}_Close'})
    if len(df) < 50:
        raise ValueError(f"Datos insuficientes para {ticker}")
    return df


def _forecast_dates(dates, start, periods):
    """Return the *periods* dates that forecast steps 0..periods-1 belong to.

    The model context is a run of consecutive history rows, so forecast step k
    corresponds to row ``start + k``. Where those rows exist (back-testing),
    their real dates are used; this keeps the forecast aligned with the real
    closes across market holidays. Steps past the end of the history are
    placed on the business days that follow the last known date.
    """
    known = pd.DatetimeIndex(dates.iloc[start:start + periods])
    missing = periods - len(known)
    if missing <= 0:
        return known
    extra = pd.date_range(
        start=dates.iloc[-1] + pd.Timedelta(days=1),
        periods=missing,
        freq='B',
    )
    return known.append(extra)


def _error_metrics(actual, predicted):
    """Return ``(MAE, RMSE, MAPE in %)`` for two equally long 1-D arrays."""
    mae = mean_absolute_error(actual, predicted)
    rmse = np.sqrt(mean_squared_error(actual, predicted))
    mape = np.mean(np.abs((actual - predicted) / actual)) * 100
    return mae, rmse, mape


def _plot_forecast(ticker, df, close_col, train_data_points,
                   prediction_dates, low, median, high, actual):
    """Plot recent history, the forecast median, its 1-99% band and any validation closes.

    *actual* holds the real closes for the first ``len(actual)`` forecast steps.
    It may be empty. The returned Figure is closed in pyplot so repeated calls
    do not accumulate open figures. The Figure object itself stays usable.
    """
    fig, ax = plt.subplots(figsize=(20, 10))
    n_overlap = len(actual)

    # Up to 10 context rows before the forecast, plus the real rows inside the horizon.
    start_index = max(0, train_data_points - 10)
    end_index = train_data_points + n_overlap
    ax.plot(df['Date'].iloc[start_index:end_index],
            df[close_col].iloc[start_index:end_index].to_numpy(),
            color='blue',
            linewidth=2,
            label='Datos Reales')

    ax.plot(prediction_dates,
            median,
            color='black',
            linewidth=2,
            linestyle='-',
            label='Predicción')
    # Band between the 1% and 99% sample quantiles.
    ax.fill_between(prediction_dates, low, high,
                    color='gray', alpha=0.2,
                    label='Intervalo de Confianza')

    if n_overlap:
        ax.plot(prediction_dates[:n_overlap],
                actual,
                color='red',
                linewidth=2,
                linestyle='--',
                label='Datos Reales de Validación')
        mae, rmse, mape = _error_metrics(actual, median[:n_overlap])
        ax.set_title(f"Predicción del Precio de {ticker}\nMAE: {mae:.2f} | RMSE: {rmse:.2f} | MAPE: {mape:.2f}%",
                     fontsize=14, pad=20)
    else:
        ax.set_title(f"Predicción Futura del Precio de {ticker}",
                     fontsize=14, pad=20)

    ax.legend(loc="upper left", fontsize=12)
    ax.set_xlabel("Fecha", fontsize=12)
    ax.set_ylabel("Precio", fontsize=12)
    ax.grid(True, which='both', axis='x', linestyle='--', linewidth=0.5)
    ax.xaxis.set_major_locator(mdates.DayLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()
    # Drop pyplot's global reference. The Figure can still be rendered by the caller.
    plt.close(fig)
    return fig


def _write_forecast_csv(prediction_dates, low, median, high, actual):
    """Write the forecast table to a temporary CSV file and return its path.

    The file is created with ``delete=False`` so it outlives this call. When
    real closes overlap the forecast, a ``Real_Price`` column is added; rows
    beyond the available history get NaN in it.
    """
    prediction_df = pd.DataFrame({
        'Date': prediction_dates,
        'Predicted_Price': median,
        'Lower_Bound': low,
        'Upper_Bound': high,
    })
    if len(actual):
        real = np.full(len(prediction_df), np.nan)
        real[:len(actual)] = actual
        prediction_df['Real_Price'] = real
    # Write through the open handle, which is closed even if to_csv raises.
    # Re-opening the path while the handle is still open fails on Windows.
    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv',
                                     newline='', encoding='utf-8') as tmp:
        prediction_df.to_csv(tmp, index=False)
    return tmp.name


def predict_stock(ticker, train_data_points, prediction_days):
    """Forecast closing prices of *ticker* with Chronos and plot the result.

    The first ``train_data_points`` closes of the full history are the model
    context, and ``prediction_days`` steps are forecast. If the history
    continues past the context, the overlapping real closes are drawn as a
    validation line. MAE / RMSE / MAPE over those points are shown in the title.

    Args:
        ticker: Yahoo Finance symbol.
        train_data_points: number of leading history rows used as context,
            converted with ``int`` and clipped to the available history.
        prediction_days: number of forecast steps, converted with ``int``.

    Returns:
        ``(figure, csv_path)``: the matplotlib Figure, and the path of a
        temporary CSV holding the forecast. The CSV also has ``Real_Price``
        when real data overlaps the forecast.

    Raises:
        gr.Error: wrapping any failure (download, insufficient data, model, plotting).
    """
    try:
        # Slider values may arrive as floats.
        train_data_points = int(train_data_points)
        prediction_days = int(prediction_days)

        pipeline = _get_pipeline()
        df = _load_close_history(ticker)
        close_col = f'{ticker}_Close'
        closes = df[close_col].to_numpy()
        total_points = len(df)

        # Never ask for more context than the history contains.
        train_data_points = min(train_data_points, total_points)

        context = torch.tensor(closes[:train_data_points], dtype=torch.float32)
        forecast = pipeline.predict(context, prediction_days, limit_prediction_length=False)
        # Quantiles over the sample axis of the single series.
        low, median, high = np.quantile(forecast[0].numpy(), [0.01, 0.5, 0.99], axis=0)

        prediction_dates = _forecast_dates(df['Date'], train_data_points, prediction_days)
        # Forecast steps for which the history already has real closes.
        n_overlap = max(0, min(prediction_days, total_points - train_data_points))
        actual = closes[train_data_points:train_data_points + n_overlap]

        fig = _plot_forecast(ticker, df, close_col, train_data_points,
                             prediction_dates, low, median, high, actual)
        csv_path = _write_forecast_csv(prediction_dates, low, median, high, actual)
        return fig, csv_path
    except Exception as e:
        print(f"Error: {str(e)}")
        raise gr.Error(f"Error al procesar {ticker}: {str(e)}") from e
def update_train_data_points(ticker):
    """Resize the training-points slider to the history length of *ticker*.

    Returns a Gradio update for the slider:
    - the maximum is the number of history rows;
    - the value is ``min(1000, rows)``;
    - the minimum is 50.

    For an empty ticker, or on any failure, a default update (value 1000,
    maximum 5000) is returned instead of raising, so the UI stays usable.

    ``gr.update`` is used because ``gr.Slider.update`` was removed in Gradio 4.
    ``gr.update`` also works on Gradio 3.x.
    """
    if not ticker:
        return gr.update(value=1000, maximum=5000)
    try:
        hist = yf.Ticker(ticker).history(period="max")
        if hist.empty:
            raise ValueError(f"No hay datos disponibles para {ticker}")
        total_points = len(hist)
        if total_points < 50:
            raise ValueError(f"Datos insuficientes para {ticker}")
        return gr.update(
            maximum=total_points,
            value=min(1000, total_points),
            minimum=50,
            step=1,
            interactive=True,
        )
    except Exception as e:
        # Best effort: log and fall back to the default slider range.
        print(f"Error al actualizar datos para {ticker}: {str(e)}")
        return gr.update(value=1000, maximum=5000, minimum=50, step=1)
# Gradio interface: controls on the left, chart and CSV download on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Aplicación de Predicción de Precios de Acciones")
    with gr.Row():
        with gr.Column(scale=1):
            ticker = gr.Dropdown(
                choices=get_popular_tickers(),
                value="AAPL",
                label="Selecciona el Símbolo de la Acción",
                interactive=True,
            )
            with gr.Column():
                train_data_points = gr.Slider(
                    minimum=50,
                    maximum=5000,
                    value=1000,
                    step=1,
                    label="Número de Datos para Entrenamiento",
                    interactive=True,
                )
                prediction_days = gr.Slider(
                    minimum=1,
                    maximum=60,
                    value=5,
                    step=1,
                    label="Número de Días a Predecir",
                    interactive=True,
                )
            predict_btn = gr.Button("Predecir", interactive=True)
        with gr.Column():
            # Hidden and not wired to any event; predict_stock reports failures by raising gr.Error.
            error_output = gr.Textbox(label="Estado", visible=False)
            plot_output = gr.Plot(label="Gráfico de Predicción")
            download_btn = gr.File(label="Descargar Predicciones")

    # Events
    # Changing the ticker rescales the training-size slider to that ticker's history.
    ticker.change(
        fn=update_train_data_points,
        inputs=[ticker],
        outputs=[train_data_points],
        api_name="update_data",
    )
    predict_btn.click(
        fn=predict_stock,
        inputs=[ticker, train_data_points, prediction_days],
        outputs=[plot_output, download_btn],
    )

# Start the server only when run as a script, so importing this module does not launch it.
if __name__ == "__main__":
    demo.launch()