seawolf2357 commited on
Commit
a805914
ยท
verified ยท
1 Parent(s): 2217397

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -73
app.py CHANGED
@@ -2,6 +2,11 @@ from huggingface_hub import InferenceClient
2
  import gradio as gr
3
  from transformers import GPT2Tokenizer
4
  import yfinance as yf
 
 
 
 
 
5
 
6
  client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
7
  tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
@@ -49,9 +54,37 @@ def format_prompt(message, history):
49
 
50
  def get_stock_data(ticker):
51
  stock = yf.Ticker(ticker)
52
- hist = stock.history(period="5d") # ์ง€๋‚œ 5์ผ๊ฐ„์˜ ์ฃผ์‹ ๋ฐ์ดํ„ฐ๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
53
  return hist
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def generate(prompt, history=[], temperature=0.1, max_new_tokens=10000, top_p=0.95, repetition_penalty=1.0):
56
  global total_tokens_used
57
  input_tokens = len(tokenizer.encode(prompt))
@@ -64,82 +97,17 @@ def generate(prompt, history=[], temperature=0.1, max_new_tokens=10000, top_p=0.
64
  formatted_prompt = format_prompt(prompt, history)
65
  output_accumulated = ""
66
  try:
67
- # ํ‹ฐ์ปค ํ™•์ธ ๋ฐ ๋ฐ์ดํ„ฐ ์ˆ˜์ง‘
68
- stock_info = get_stock_info(prompt) # ์ข…๋ชฉ๋ช…์„ ํ† ๋Œ€๋กœ ํ‹ฐ์ปค ์ •๋ณด์™€ ๊ธฐ์—… ์„ค๋ช…์„ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
69
- if stock_info['ticker']:
70
- response_msg = f"{stock_info['name']}์€(๋Š”) {stock_info['description']} ์ฃผ๋ ฅ์œผ๋กœ ์ƒ์‚ฐํ•˜๋Š” ๊ธฐ์—…์ž…๋‹ˆ๋‹ค. {stock_info['name']}์˜ ํ‹ฐ์ปค๋Š” {stock_info['ticker']}์ž…๋‹ˆ๋‹ค. ์›ํ•˜์‹œ๋Š” ์ข…๋ชฉ์ด ๋งž๋Š”๊ฐ€์š”?"
71
- output_accumulated += response_msg
72
- yield output_accumulated
73
-
74
- # ์ถ”๊ฐ€์ ์ธ ๋ถ„์„ ์š”์ฒญ์ด ์žˆ๋‹ค๋ฉด, yfinance๋กœ ๋ฐ์ดํ„ฐ ์ˆ˜์ง‘ ๋ฐ ๋ถ„์„
75
- stock_data = get_stock_data(stock_info['ticker']) # ํ‹ฐ์ปค๋ฅผ ์ด์šฉํ•ด ์ฃผ์‹ ๋ฐ์ดํ„ฐ๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
76
- stream = client.text_generation(
77
- formatted_prompt,
78
- temperature=temperature,
79
- max_new_tokens=min(max_new_tokens, available_tokens),
80
- top_p=top_p,
81
- repetition_penalty=repetition_penalty,
82
- do_sample=True,
83
- seed=42,
84
- stream=True
85
- )
86
- for response in stream:
87
- output_part = response['generated_text'] if 'generated_text' in response else str(response)
88
- output_accumulated += output_part
89
- yield output_accumulated + f"\n\n---\nTotal tokens used: {total_tokens_used}\nStock Data: {stock_data}"
90
  else:
91
- # ์ž…๋ ฅ์ด ํ‹ฐ์ปค์ธ ๊ฒฝ์šฐ ์ฒ˜๋ฆฌ
92
- ticker = prompt.upper()
93
- if ticker in ['AAPL', 'MSFT', 'AMZN', 'GOOGL', 'TSLA']:
94
- stock_info = get_stock_info_by_ticker(ticker)
95
- response_msg = f"{stock_info['name']}์€(๋Š”) {stock_info['description']} ์ฃผ๋ ฅ์œผ๋กœ ์ƒ์‚ฐํ•˜๋Š” ๊ธฐ์—…์ž…๋‹ˆ๋‹ค. {stock_info['name']}์˜ ํ‹ฐ์ปค๋Š” {stock_info['ticker']}์ž…๋‹ˆ๋‹ค. ์›ํ•˜์‹œ๋Š” ์ข…๋ชฉ์ด ๋งž๋Š”๊ฐ€์š”?"
96
- output_accumulated += response_msg
97
- yield output_accumulated
98
-
99
- # ์ถ”๊ฐ€์ ์ธ ๋ถ„์„ ์š”์ฒญ์ด ์žˆ๋‹ค๋ฉด, yfinance๋กœ ๋ฐ์ดํ„ฐ ์ˆ˜์ง‘ ๋ฐ ๋ถ„์„
100
- stock_data = get_stock_data(stock_info['ticker']) # ํ‹ฐ์ปค๋ฅผ ์ด์šฉํ•ด ์ฃผ์‹ ๋ฐ์ดํ„ฐ๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
101
- stream = client.text_generation(
102
- formatted_prompt,
103
- temperature=temperature,
104
- max_new_tokens=min(max_new_tokens, available_tokens),
105
- top_p=top_p,
106
- repetition_penalty=repetition_penalty,
107
- do_sample=True,
108
- seed=42,
109
- stream=True
110
- )
111
- for response in stream:
112
- output_part = response['generated_text'] if 'generated_text' in response else str(response)
113
- output_accumulated += output_part
114
- yield output_accumulated + f"\n\n---\nTotal tokens used: {total_tokens_used}\nStock Data: {stock_data}"
115
- else:
116
- yield f"์ž…๋ ฅํ•˜์‹  '{prompt}'์€(๋Š”) ์ง€์›๋˜๋Š” ์ข…๋ชฉ๋ช… ๋˜๋Š” ํ‹ฐ์ปค๊ฐ€ ์•„๋‹™๋‹ˆ๋‹ค. ํ˜„์žฌ ์ง€์›๋˜๋Š” ์ข…๋ชฉ์€ ์• ํ”Œ(AAPL), ๋งˆ์ดํฌ๋กœ์†Œํ”„ํŠธ(MSFT), ์•„๋งˆ์กด(AMZN), ์•ŒํŒŒ๋ฒณ(GOOGL), ํ…Œ์Šฌ๋ผ(TSLA) ๋“ฑ์ž…๋‹ˆ๋‹ค. ์ •ํ™•ํ•œ ์ข…๋ชฉ๋ช… ๋˜๋Š” ํ‹ฐ์ปค๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”."
117
  except Exception as e:
118
  yield f"Error: {str(e)}\nTotal tokens used: {total_tokens_used}"
119
 
120
- # ํ‹ฐ์ปค๋ฅผ ํ† ๋Œ€๋กœ ์ข…๋ชฉ ์ •๋ณด๋ฅผ ์ œ๊ณตํ•˜๋Š” ํ•จ์ˆ˜
121
- def get_stock_info_by_ticker(ticker):
122
- stock_info = {
123
- "AAPL": {'ticker': 'AAPL', 'name': '์• ํ”Œ', 'description': '์•„์ดํฐ์„'},
124
- "MSFT": {'ticker': 'MSFT', 'name': '๋งˆ์ดํฌ๋กœ์†Œํ”„ํŠธ', 'description': '์œˆ๋„์šฐ ์šด์˜์ฒด์ œ์™€ ์˜คํ”ผ์Šค ์†Œํ”„ํŠธ์›จ์–ด๋ฅผ'},
125
- "AMZN": {'ticker': 'AMZN', 'name': '์•„๋งˆ์กด', 'description': '์ „์ž์ƒ๊ฑฐ๋ž˜ ๋ฐ ํด๋ผ์šฐ๋“œ ์„œ๋น„์Šค๋ฅผ'},
126
- "GOOGL": {'ticker': 'GOOGL', 'name': '์•ŒํŒŒ๋ฒณ', 'description': '๊ฒ€์ƒ‰ ์—”์ง„ ๋ฐ ์˜จ๋ผ์ธ ๊ด‘๊ณ ๋ฅผ'},
127
- "TSLA": {'ticker': 'TSLA', 'name': 'ํ…Œ์Šฌ๋ผ', 'description': '์ „๊ธฐ์ž๋™์ฐจ์™€ ์—๋„ˆ์ง€ ์ €์žฅ์žฅ์น˜๋ฅผ'},
128
- }
129
- return stock_info.get(ticker, {'ticker': None, 'name': None, 'description': ''})
130
-
131
- # ์ข…๋ชฉ๋ช…์„ ํ† ๋Œ€๋กœ ํ‹ฐ์ปค์™€ ๊ธฐ์—… ์ •๋ณด๋ฅผ ์ œ๊ณตํ•˜๋Š” ํ•จ์ˆ˜
132
- def get_stock_info(name):
133
- stock_info = {
134
- "apple": {'ticker': 'AAPL', 'name': '์• ํ”Œ', 'description': '์•„์ดํฐ์„'},
135
- "microsoft": {'ticker': 'MSFT', 'name': '๋งˆ์ดํฌ๋กœ์†Œํ”„ํŠธ', 'description': '์œˆ๋„์šฐ ์šด์˜์ฒด์ œ์™€ ์˜คํ”ผ์Šค ์†Œํ”„ํŠธ์›จ์–ด๋ฅผ'},
136
- "amazon": {'ticker': 'AMZN', 'name': '์•„๋งˆ์กด', 'description': '์ „์ž์ƒ๊ฑฐ๋ž˜ ๋ฐ ํด๋ผ์šฐ๋“œ ์„œ๋น„์Šค๋ฅผ'},
137
- "google": {'ticker': 'GOOGL', 'name': '์•ŒํŒŒ๋ฒณ (๊ตฌ๊ธ€)', 'description': '๊ฒ€์ƒ‰ ์—”์ง„ ๋ฐ ์˜จ๋ผ์ธ ๊ด‘๊ณ ๋ฅผ'},
138
- "tesla": {'ticker': 'TSLA', 'name': 'ํ…Œ์Šฌ๋ผ', 'description': '์ „๊ธฐ์ž๋™์ฐจ์™€ ์—๋„ˆ์ง€ ์ €์žฅ์žฅ์น˜๋ฅผ'},
139
- # ์ถ”๊ฐ€์ ์ธ ์ข…๋ชฉ์— ๋Œ€ํ•œ ์ •๋ณด๋ฅผ ์ด๊ณณ์— ๊ตฌํ˜„ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
140
- }
141
- return stock_info.get(name.lower(), {'ticker': None, 'name': name, 'description': ''})
142
-
143
  mychatbot = gr.Chatbot(
144
  avatar_images=["./user.png", "./botm.png"],
145
  bubble_full_width=False,
 
2
  import gradio as gr
3
  from transformers import GPT2Tokenizer
4
  import yfinance as yf
5
+ import pandas as pd
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ import talib
9
+ import tech_indicators as ti
10
 
11
  client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
12
  tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 
54
 
55
  def get_stock_data(ticker):
56
  stock = yf.Ticker(ticker)
57
+ hist = stock.history(period="6mo") # ์ง€๋‚œ 6๊ฐœ์›”๊ฐ„์˜ ๋ฐ์ดํ„ฐ๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
58
  return hist
59
 
60
+ def apply_technical_indicators(df):
61
+ df['SMA'] = talib.SMA(df['Close'], timeperiod=20)
62
+ df['EMA'] = talib.EMA(df['Close'], timeperiod=20)
63
+ df['RSI'] = talib.RSI(df['Close'], timeperiod=14)
64
+ macd, macdsignal, macdhist = talib.MACD(df['Close'], fastperiod=12, slowperiod=26, signalperiod=9)
65
+ df['MACD'] = macd
66
+ df['MACD_signal'] = macdsignal
67
+ return df
68
+
69
+ def plot_technical_indicators(df):
70
+ plt.figure(figsize=(14, 7))
71
+ plt.subplot(2, 1, 1)
72
+ plt.plot(df['Close'], label='Close Price')
73
+ plt.plot(df['SMA'], label='SMA 20')
74
+ plt.plot(df['EMA'], label='EMA 20')
75
+ plt.title('Price Chart with SMA and EMA')
76
+ plt.legend()
77
+
78
+ plt.subplot(2, 1, 2)
79
+ plt.plot(df['RSI'], label='RSI')
80
+ plt.title('RSI Chart')
81
+ plt.legend()
82
+
83
+ plt.tight_layout()
84
+ plt.savefig('/mnt/data/Technical_Indicators.png')
85
+ plt.close()
86
+ return '/mnt/data/Technical_Indicators.png'
87
+
88
  def generate(prompt, history=[], temperature=0.1, max_new_tokens=10000, top_p=0.95, repetition_penalty=1.0):
89
  global total_tokens_used
90
  input_tokens = len(tokenizer.encode(prompt))
 
97
  formatted_prompt = format_prompt(prompt, history)
98
  output_accumulated = ""
99
  try:
100
+ ticker = prompt.upper()
101
+ stock_data = get_stock_data(ticker)
102
+ if not stock_data.empty:
103
+ enhanced_data = apply_technical_indicators(stock_data)
104
+ image_path = plot_technical_indicators(enhanced_data)
105
+ yield f"Technical analysis for {ticker} completed. See the chart here: {image_path}\n\n---\nTotal tokens used: {total_tokens_used}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  else:
107
+ yield f"No data available for {ticker}. Please check the ticker and try again."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  except Exception as e:
109
  yield f"Error: {str(e)}\nTotal tokens used: {total_tokens_used}"
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  mychatbot = gr.Chatbot(
112
  avatar_images=["./user.png", "./botm.png"],
113
  bubble_full_width=False,