|
import gradio as gr |
|
import yfinance as yf |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import matplotlib as mpl |
|
import matplotlib.font_manager as fm |
|
import tempfile |
|
import requests |
|
from prophet import Prophet |
|
from datetime import datetime, timedelta |
|
import logging |
|
|
|
|
|
# Root logger: emit INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
|
|
|
|
|
def setup_font():
    """Download a font file and select it as matplotlib's default family.

    The file behind ``url_font`` is fetched over HTTP, written to a temporary
    ``.ttf`` file, registered with matplotlib's font manager, and the default
    font family is set to ``Taipei Sans TC Beta``.

    This is best-effort: if any step fails, the error is logged and the font
    family is set to ``SimHei`` instead. Nothing is returned or raised.
    """
    url_font = "https://drive.google.com/uc?id=1eGAsTN1HBpJAkeVM57_C7ccp7hbgSz3_"
    try:
        # Bound the wait: without a timeout a stalled connection blocks forever.
        response_font = requests.get(url_font, timeout=30)
        # Treat HTTP errors as failures instead of saving an error page as a font.
        response_font.raise_for_status()

        # delete=False keeps the file on disk after the ``with`` block so it can
        # be registered below.
        # NOTE(review): the file is never removed; presumably matplotlib needs
        # the path to stay valid for later rendering - confirm before adding cleanup.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.ttf') as tmp_file:
            tmp_file.write(response_font.content)
            tmp_file_path = tmp_file.name

        fm.fontManager.addfont(tmp_file_path)
        mpl.rc('font', family='Taipei Sans TC Beta')
    except Exception as e:
        # Keep the traceback in the log; the message text itself is unchanged.
        logging.exception(f"字體設置失敗: {str(e)}")
        # Fallback font family when the download or registration failed.
        mpl.rc('font', family='SimHei')
|
|
|
def _plot_forecast(data, forecast, stock_code):
    """Draw actual closes, predicted closes and the prediction band on a new figure.

    ``data`` must provide ``ds``/``y`` columns and ``forecast`` must provide
    ``ds``/``yhat``/``yhat_lower``/``yhat_upper``. Returns the figure.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.plot(data['ds'], data['y'],
                label='實際收盤價',
                color='blue')
        ax.plot(forecast['ds'], forecast['yhat'],
                label='預測收盤價',
                color='orange',
                linestyle='--')
        ax.fill_between(forecast['ds'],
                        forecast['yhat_lower'],
                        forecast['yhat_upper'],
                        color='orange',
                        alpha=0.2)

        ax.set_title(f'{stock_code} 股價預測', pad=20)
        ax.set_xlabel('日期')
        ax.set_ylabel('股價')
        ax.tick_params(axis='x', labelrotation=45)
        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.7)
        fig.tight_layout()
    finally:
        # Unregister the figure from pyplot so repeated calls do not pile up
        # open figures; the Figure object itself is still returned to the caller.
        plt.close(fig)
    return fig


def predict_stock_price(stock_code, period, prediction_days):
    """Fit a Prophet model on historical closes and forecast future prices.

    Args:
        stock_code: Ticker symbol passed to ``yf.download`` (e.g. ``2330.TW``).
            Surrounding whitespace is ignored.
        period: History window passed to ``yf.download`` as ``period``.
        prediction_days: Number of future rows to forecast; coerced to ``int``.

    Returns:
        A ``(text, figure)`` tuple. ``text`` is the string dump of the last
        ``prediction_days`` forecast rows and ``figure`` is the matplotlib
        figure. When the code is blank or no data is downloaded, returns
        ``("無法獲取股票數據", None)``. On any other failure the error is logged
        and ``(error message, None)`` is returned; nothing is raised.
    """
    try:
        # Reject blank input up front instead of issuing a pointless download.
        stock_code = (stock_code or "").strip()
        if not stock_code:
            return "無法獲取股票數據", None

        # UI sliders may deliver a float; the calls below need an integer count.
        prediction_days = int(prediction_days)

        df = yf.download(stock_code, period=period)
        if df.empty:
            return "無法獲取股票數據", None

        # Prophet expects the columns to be named ``ds`` (date) and ``y`` (value).
        data = df.reset_index()[['Date', 'Close']].copy()
        data.columns = ['ds', 'y']
        # Drop any timezone info; Prophet rejects tz-aware ``ds`` values.
        data['ds'] = pd.to_datetime(data['ds']).dt.tz_localize(None)

        model = Prophet(daily_seasonality=True)
        model.fit(data)

        # NOTE(review): no ``freq`` is passed, so the future rows follow
        # Prophet's default frequency - presumably calendar days, which would
        # include non-trading days; confirm this is intended.
        future = model.make_future_dataframe(periods=prediction_days)
        forecast = model.predict(future)

        figure = _plot_forecast(data, forecast, stock_code)
        return forecast.tail(prediction_days).to_string(), figure

    except Exception as e:
        # Keep the traceback in the log; the user-facing message is unchanged.
        logging.exception(f"預測過程發生錯誤: {str(e)}")
        return f"預測過程發生錯誤: {str(e)}", None
|
|
|
|
|
# Configure the matplotlib font once at import time, before any figure is drawn.
setup_font()

# Gradio UI definition.
# NOTE(review): the original indentation of this section was lost, so the
# Column/Row nesting below is a reconstruction - confirm the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# 股票價格 Stock Price 預測系統 Prophet")

    with gr.Column():
        # Input controls.
        with gr.Row():
            # Ticker symbol, forwarded as ``stock_code``.
            stock_input = gr.Textbox(
                label="股票代碼",
                placeholder="例如: 2330.TW",
                value="2330.TW"
            )

            # History window, forwarded as ``period``.
            period_dropdown = gr.Dropdown(
                choices=["1mo", "3mo", "6mo", "1y", "2y", "5y", "10y", "max"],
                label="歷史數據期間",
                value="1y"
            )

            # Forecast horizon (5-30, step 1), forwarded as ``prediction_days``.
            prediction_days = gr.Slider(
                minimum=5,
                maximum=30,
                value=5,
                step=1,
                label="預測天數"
            )

        predict_button = gr.Button("開始預測", variant="primary")

        # Output widgets.
        output_plot = gr.Plot(label="股價預測圖")
        output_text = gr.Textbox(label="預測結果", lines=10)

    # The outputs are listed text-first to match the (text, figure) tuple
    # returned by predict_stock_price.
    predict_button.click(
        predict_stock_price,
        inputs=[stock_input, period_dropdown, prediction_days],
        outputs=[output_text, output_plot]
    )

# Start the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()