parkerjj committed on
Commit
8ec911f
·
1 Parent(s): d48ef09

优化 Dockerfile 和 us_stock.py,增加 uvicorn 工作进程数,添加股票最新价格缓存功能,简化获取股票信息逻辑

Browse files
Files changed (4) hide show
  1. Dockerfile +1 -1
  2. blkeras.py +220 -205
  3. preprocess.py +26 -55
  4. us_stock.py +32 -2
Dockerfile CHANGED
@@ -44,4 +44,4 @@ RUN --mount=type=secret,id=HF_Token,mode=0444,required=true \
44
  # git clone $(cat /run/secrets/HF_Token)
45
 
46
 
47
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
 
44
  # git clone $(cat /run/secrets/HF_Token)
45
 
46
 
47
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "3"]
blkeras.py CHANGED
@@ -21,7 +21,7 @@ import os
21
 
22
  from RequestModel import PredictRequest
23
  from app import TextRequest
24
- from us_stock import find_stock_codes_or_names
25
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
26
  # 设置环境变量,指定 Hugging Face 缓存路径
27
  os.environ["HF_HOME"] = "/tmp/huggingface"
@@ -83,6 +83,14 @@ def generate_fake_accuracy():
83
  return round(fake_accuracy, 5)
84
 
85
 
 
 
 
 
 
 
 
 
86
 
87
 
88
  def predict(text: str, stock_codes: list):
@@ -111,7 +119,7 @@ def predict(text: str, stock_codes: list):
111
  #print("Dependency Parsing:", dependency_parsing)
112
  #print("Sentiment Score:", sentiment_score)
113
 
114
- if affected_stock_codes is None:
115
  # 从 NER 结果中提取相关的股票代码或公司名称
116
  affected_stock_codes = find_stock_codes_or_names(ner)
117
 
@@ -119,268 +127,275 @@ def predict(text: str, stock_codes: list):
119
  cache_key = generate_key(lemmatized_entry)
120
  # 检查缓存中是否已有结果
121
  if cache_key in prediction_cache:
122
- print(f"Cache hit: {cache_key} lemmatized_entry: {lemmatized_entry}" )
123
  return prediction_cache[cache_key]
124
 
125
 
 
 
 
 
126
  # 调用 get_stock_info 函数
127
- previous_stock_history, _, previous_stock_inx_index_history, previous_stock_dj_index_history, previous_stock_ixic_index_history, previous_stock_ndx_index_history, _, _, _, _ = get_stock_info(affected_stock_codes)
128
 
 
 
129
 
130
- def ensure_fixed_shape(data, shape, variable_name=""):
131
- data = np.array(data)
132
- if data.shape != shape:
133
- fixed_data = np.full(shape, -1)
134
- min_shape = tuple(min(s1, s2) for s1, s2 in zip(data.shape, shape))
135
- fixed_data[:min_shape[0], :min_shape[1], :min_shape[2]] = data[:min_shape[0], :min_shape[1], :min_shape[2]]
136
- return fixed_data
137
- return data
138
 
139
- previous_stock_history = ensure_fixed_shape(previous_stock_history, (1, 30, 6), "previous_stock_history")
140
- previous_stock_inx_index_history = ensure_fixed_shape(previous_stock_inx_index_history, (1, 30, 6), "previous_stock_inx_index_history")
141
- previous_stock_dj_index_history = ensure_fixed_shape(previous_stock_dj_index_history, (1, 30, 6), "previous_stock_dj_index_history")
142
- previous_stock_ixic_index_history = ensure_fixed_shape(previous_stock_ixic_index_history, (1, 30, 6), "previous_stock_ixic_index_history")
143
- previous_stock_ndx_index_history = ensure_fixed_shape(previous_stock_ndx_index_history, (1, 30, 6), "previous_stock_ndx_index_history")
144
 
145
 
146
-
 
 
 
 
147
 
148
- # 3. 将特征转换为适合模型输入的形状
149
- # 这里假设文本、POS、实体识别等是向量,时间序列特征是 (sequence_length, feature_dim) 的形状
150
 
151
-
152
- # POS 和 NER 特征处理
153
- # 只取 POS Tagging 的第二部分(即 POS 标签的字母形式)进行处理
154
- pos_results = [process_pos_tags(pos_tag[1])[0]] # 传入 POS 标签列表
155
- ner_results = [process_entities(ner)[0]] # 假设是单个输入
156
 
 
 
 
 
 
157
 
158
- print("POS Results:", pos_results)
159
- print("NER Results:", ner_results)
160
 
161
- # 使用与模型定义一致的 pos_tag_dim 和 entity_dim
162
- pos_tag_dim = 1024 # 你需要根据模型定义来确定
163
- entity_dim = 1024 # 你需要根据模型定义来确定
164
 
165
- # 调整 max_length 为与 pos_tag_dim 和 entity_dim 一致的值
166
- X_pos_tags = pad_sequences(pos_results, maxlen=pos_tag_dim, padding='post', truncating='post', dtype='float32')
167
- X_entities = pad_sequences(ner_results, maxlen=entity_dim, padding='post', truncating='post', dtype='float32')
168
 
169
- # 确保形状为 (1, 1024)
170
- X_pos_tags = X_pos_tags.reshape(1, -1)
171
- X_entities = X_entities.reshape(1, -1)
172
 
173
- # Word2Vec 向量处理
174
- lemmatized_words = lemmatized_entry # 这里是 lemmatized_entry 的结果
175
- if not lemmatized_words:
176
- raise ValueError("Lemmatized words are empty.")
177
 
178
- X_word2vec = np.array([get_document_vector(lemmatized_words)], dtype='float32') # 使用 get_document_vector 将 lemmatized_words 转为向量
 
 
 
179
 
180
- # 情感得分
181
- X_sentiment = np.array([[sentiment_score]], dtype='float32') # sentiment_score 已经是单值,直接转换为二维数组
182
 
183
- # 打印输入特征的形状,便于调试
184
- # print("X_word2vec shape:", X_word2vec.shape)
185
- # print("X_pos_tags shape:", X_pos_tags.shape)
186
- # print("X_entities shape:", X_entities.shape)
187
- # print("X_sentiment shape:", X_sentiment.shape)
188
 
 
 
 
 
 
189
 
190
 
191
- # 静态特征
192
- X_word2vec = ensure_fixed_shape(X_word2vec, (1, 300), "X_word2vec")
193
- X_pos_tags = ensure_fixed_shape(X_pos_tags, (1, 1024), "X_pos_tags")
194
- X_entities = ensure_fixed_shape(X_entities, (1, 1024), "X_entities")
195
- X_sentiment = ensure_fixed_shape(X_sentiment, (1, 1), "X_sentiment")
196
 
 
 
 
 
 
197
 
198
 
199
- features = [
200
- X_word2vec, X_pos_tags, X_entities, X_sentiment,
201
- previous_stock_inx_index_history, previous_stock_dj_index_history,
202
- previous_stock_ixic_index_history, previous_stock_ndx_index_history,
203
- previous_stock_history
204
- ]
205
 
 
 
 
 
 
 
206
 
207
 
208
- # 打印特征数组的每个元素的形状,便于调试
209
- # for i, feature in enumerate(features):
210
- # print(f"Feature {i} shape: {feature.shape} value: {feature[0]} length: {len(feature[0])}")
211
- # for name, feature in enumerate(features):
212
- # print(f"模型输入数据 {name} shape: {feature.shape}")
213
 
214
- # for layer in model.input:
215
- # print(f"模型所需的输入层 {layer.name}, 形状: {layer.shape}")
 
 
 
216
 
217
- # 使用模型进行预测
218
- predictions = model.predict(features)
219
 
220
- # 生成伪精准度值
221
- fake_accuracy = generate_fake_accuracy()
222
 
223
- # 将 predictions 中的每个数组转换为 Python 列表
224
- index_inx_predictions = predictions[0].tolist()
225
- index_dj_predictions = predictions[1].tolist()
226
- index_ixic_predictions = predictions[2].tolist()
227
- index_ndx_predictions = predictions[3].tolist()
228
- stock_predictions = predictions[4].tolist()
229
 
230
- # 打印预测结果,便于调试
231
- #print("Index INX Predictions:", index_inx_predictions)
232
- #print("Index DJ Predictions:", index_dj_predictions)
233
- #print("Index IXIC Predictions:", index_ixic_predictions)
234
- #print("Index NDX Predictions:", index_ndx_predictions)
235
- #print("Stock Predictions:", stock_predictions)
236
 
 
 
 
 
 
 
 
237
 
238
- # 获取 index_feature 中最后一天的第一个值
239
- last_index_inx_value = previous_stock_inx_index_history[0][-1][0]
240
- last_index_dj_value = previous_stock_dj_index_history[0][-1][0]
241
- last_index_ixic_value = previous_stock_ixic_index_history[0][-1][0]
242
- last_index_ndx_value = previous_stock_ndx_index_history[0][-1][0]
243
- last_stock_value = previous_stock_history[0][-1][0]
244
 
 
 
 
 
 
 
245
 
 
 
246
 
247
- # 针对 1012 模型的修复
248
- stock_predictions = stock_fix_for_1118_model(float(X_sentiment[0][0]), stock_predictions[0], last_stock_value, is_index=False)
249
- index_inx_predictions = stock_fix_for_1118_model(float(X_sentiment[0][0]), index_inx_predictions[0], last_index_inx_value)
250
- index_dj_predictions = stock_fix_for_1118_model(float(X_sentiment[0][0]), index_dj_predictions[0], last_index_dj_value)
251
- index_ixic_predictions = stock_fix_for_1118_model(float(X_sentiment[0][0]), index_ixic_predictions[0], last_index_ixic_value)
252
- index_ndx_predictions = stock_fix_for_1118_model(float(X_sentiment[0][0]), index_ndx_predictions[0], last_index_ndx_value)
253
 
254
- #print("Stock Predictions after fix:", stock_predictions)
255
- #print("Index INX Predictions after fix:", index_inx_predictions)
256
- #print("Index DJ Predictions after fix:", index_dj_predictions)
257
- #print("Index IXIC Predictions after fix:", index_ixic_predictions)
258
- #print("Index NDX Predictions after fix:", index_ndx_predictions)
259
 
 
 
260
 
 
 
261
 
262
- # 提取 Index Predictions 中每一天的第一个值
263
- index_inx_day_1 = index_inx_predictions[0][0]
264
- index_inx_day_2 = index_inx_predictions[1][0]
265
- index_inx_day_3 = index_inx_predictions[2][0]
266
 
267
- index_dj_day_1 = index_dj_predictions[0][0]
268
- index_dj_day_2 = index_dj_predictions[1][0]
269
- index_dj_day_3 = index_dj_predictions[2][0]
270
 
271
- index_ixic_day_1 = index_ixic_predictions[0][0]
272
- index_ixic_day_2 = index_ixic_predictions[1][0]
273
- index_ixic_day_3 = index_ixic_predictions[2][0]
 
 
 
274
 
275
- index_ndx_day_1 = index_ndx_predictions[0][0]
276
- index_ndx_day_2 = index_ndx_predictions[1][0]
277
- index_ndx_day_3 = index_ndx_predictions[2][0]
 
 
278
 
279
- stock_day_1 = stock_predictions[0][0]
280
- stock_day_2 = stock_predictions[1][0]
281
- stock_day_3 = stock_predictions[2][0]
282
 
283
- # 计算 impact_1_day, impact_2_day, impact_3_day
284
- impact_inx_1_day = (index_inx_day_1 - last_index_inx_value) / last_index_inx_value if last_index_inx_value != 0 else 0
285
- impact_inx_2_day = (index_inx_day_2 - index_inx_day_1) / index_inx_day_1 if index_inx_day_1 != 0 else 0
286
- impact_inx_3_day = (index_inx_day_3 - index_inx_day_2) / index_inx_day_2 if index_inx_day_2 != 0 else 0
287
 
288
- impact_dj_1_day = (index_dj_day_1 - last_index_dj_value) / last_index_dj_value if last_index_dj_value != 0 else 0
289
- impact_dj_2_day = (index_dj_day_2 - index_dj_day_1) / index_dj_day_1 if index_dj_day_1 != 0 else 0
290
- impact_dj_3_day = (index_dj_day_3 - index_dj_day_2) / index_dj_day_2 if index_dj_day_2 != 0 else 0
291
-
292
- impact_ixic_1_day = (index_ixic_day_1 - last_index_ixic_value) / last_index_ixic_value if last_index_ixic_value != 0 else 0
293
- impact_ixic_2_day = (index_ixic_day_2 - index_ixic_day_1) / index_ixic_day_1 if index_ixic_day_1 != 0 else 0
294
- impact_ixic_3_day = (index_ixic_day_3 - index_ixic_day_2) / index_ixic_day_2 if index_ixic_day_2 != 0 else 0
295
 
296
- impact_ndx_1_day = (index_ndx_day_1 - last_index_ndx_value) / last_index_ndx_value if last_index_ndx_value != 0 else 0
297
- impact_ndx_2_day = (index_ndx_day_2 - index_ndx_day_1) / index_ndx_day_1 if index_ndx_day_1 != 0 else 0
298
- impact_ndx_3_day = (index_ndx_day_3 - index_ndx_day_2) / index_ndx_day_2 if index_ndx_day_2 != 0 else 0
299
-
300
- impact_stock_1_day = (stock_day_1 - last_stock_value) / last_stock_value if last_stock_value != 0 else 0
301
- impact_stock_2_day = (stock_day_2 - stock_day_1) / stock_day_1 if stock_day_1 != 0 else 0
302
- impact_stock_3_day = (stock_day_3 - stock_day_2) / stock_day_2 if stock_day_2 != 0 else 0
303
-
304
- # impact 值转换为百分比字符串
305
- impact_inx_1_day_str = f"{impact_inx_1_day:.2%}"
306
- impact_inx_2_day_str = f"{impact_inx_2_day:.2%}"
307
- impact_inx_3_day_str = f"{impact_inx_3_day:.2%}"
308
-
309
- impact_dj_1_day_str = f"{impact_dj_1_day:.2%}"
310
- impact_dj_2_day_str = f"{impact_dj_2_day:.2%}"
311
- impact_dj_3_day_str = f"{impact_dj_3_day:.2%}"
312
-
313
- impact_ixic_1_day_str = f"{impact_ixic_1_day:.2%}"
314
- impact_ixic_2_day_str = f"{impact_ixic_2_day:.2%}"
315
- impact_ixic_3_day_str = f"{impact_ixic_3_day:.2%}"
316
-
317
- impact_ndx_1_day_str = f"{impact_ndx_1_day:.2%}"
318
- impact_ndx_2_day_str = f"{impact_ndx_2_day:.2%}"
319
- impact_ndx_3_day_str = f"{impact_ndx_3_day:.2%}"
320
-
321
- impact_stock_1_day_str = f"{impact_stock_1_day:.2%}"
322
- impact_stock_2_day_str = f"{impact_stock_2_day:.2%}"
323
- impact_stock_3_day_str = f"{impact_stock_3_day:.2%}"
324
-
325
-
326
- # 如果需要返回原始预测数据进行调试,可以直接将其放到响应中
327
- if len(affected_stock_codes) > 5:
328
- affected_stock_codes_str = "/".join(affected_stock_codes[:3]) + f" and {len(affected_stock_codes)} other stocks"
329
- else:
330
- affected_stock_codes_str = "/".join(affected_stock_codes) if affected_stock_codes else "N/A"
331
-
332
-
333
-
334
-
335
- # 扩展股票预测数据到分钟级别
336
- stock_predictions = extend_stock_days_to_mins(stock_predictions)
337
- index_inx_predictions = extend_stock_days_to_mins(index_inx_predictions)
338
- index_dj_predictions = extend_stock_days_to_mins(index_dj_predictions)
339
- index_ixic_predictions = extend_stock_days_to_mins(index_ixic_predictions)
340
- index_ndx_predictions = extend_stock_days_to_mins(index_ndx_predictions)
341
-
342
-
343
-
344
- # 如果需要返回原始预测数据进行调试,可以直接将其放到响应中
345
- result = {
346
- "news_title": input_text,
347
- "ai_prediction_score": float(X_sentiment[0][0]), # 假设第一个预测值是 AI 预测得分
348
- "impact_inx_1_day": impact_inx_1_day_str, # 计算并格式化 impact_1_day
349
- "impact_inx_2_day": impact_inx_2_day_str, # 计算并格式化 impact_2_day
350
- "impact_inx_3_day": impact_inx_3_day_str,
351
- "impact_dj_1_day": impact_dj_1_day_str, # 计算并格式化 impact_1_day
352
- "impact_dj_2_day": impact_dj_2_day_str, # 计算并格式化 impact_2_day
353
- "impact_dj_3_day": impact_dj_3_day_str,
354
- "impact_ixic_1_day": impact_ixic_1_day_str, # 计算并格式化 impact_1_day
355
- "impact_ixic_2_day": impact_ixic_2_day_str, # 计算并格式化 impact_2_day
356
- "impact_ixic_3_day": impact_ixic_3_day_str,
357
- "impact_ndx_1_day": impact_ndx_1_day_str, # 计算并格式化 impact_1_day
358
- "impact_ndx_2_day": impact_ndx_2_day_str, # 计算并格式化 impact_2_day
359
- "impact_ndx_3_day": impact_ndx_3_day_str,
360
- "impact_stock_1_day": impact_stock_1_day_str, # 计算并格式化 impact_1_day
361
- "impact_stock_2_day": impact_stock_2_day_str, # 计算并格式化 impact_2_day
362
- "impact_stock_3_day": impact_stock_3_day_str,
363
- "affected_stock_codes": affected_stock_codes_str, # 动态生成受影响的股票代码
364
- "accuracy": float(fake_accuracy),
365
- "impact_on_stock": stock_predictions, # 第一个预测值是股票影响
366
- "impact_on_index_inx": index_inx_predictions, # 第一个预测值是股票影响
367
- "impact_on_index_dj": index_dj_predictions, # 第一个预测值是股票影响
368
- "impact_on_index_ixic": index_ixic_predictions, # 第一个预测值是股票影响
369
- "impact_on_index_ndx": index_ndx_predictions, # 第一个预测值是股票影响
370
-
371
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
 
373
  # 缓存预测结果
374
- prediction_cache[cache_key] = result
375
 
376
  # 如果缓存大小超过最大限制,移除最早的缓存项
377
  if len(prediction_cache) > CACHE_MAX_SIZE:
378
  prediction_cache.popitem(last=False)
379
 
380
- #print(f"predict() result: {result}")
381
 
382
  # 返回预测结果
383
- return result
384
 
385
  except Exception as e:
386
  # 打印完整的错误堆栈信息
 
21
 
22
  from RequestModel import PredictRequest
23
  from app import TextRequest
24
+ from us_stock import find_stock_codes_or_names, get_last_minute_stock_price
25
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
26
  # 设置环境变量,指定 Hugging Face 缓存路径
27
  os.environ["HF_HOME"] = "/tmp/huggingface"
 
83
  return round(fake_accuracy, 5)
84
 
85
 
86
+ def ensure_fixed_shape(data, shape, variable_name=""):
87
+ data = np.array(data)
88
+ if data.shape != shape:
89
+ fixed_data = np.full(shape, -1)
90
+ min_shape = tuple(min(s1, s2) for s1, s2 in zip(data.shape, shape))
91
+ fixed_data[:min_shape[0], :min_shape[1], :min_shape[2]] = data[:min_shape[0], :min_shape[1], :min_shape[2]]
92
+ return fixed_data
93
+ return data
94
 
95
 
96
  def predict(text: str, stock_codes: list):
 
119
  #print("Dependency Parsing:", dependency_parsing)
120
  #print("Sentiment Score:", sentiment_score)
121
 
122
+ if affected_stock_codes is None or not affected_stock_codes:
123
  # 从 NER 结果中提取相关的股票代码或公司名称
124
  affected_stock_codes = find_stock_codes_or_names(ner)
125
 
 
127
  cache_key = generate_key(lemmatized_entry)
128
  # 检查缓存中是否已有结果
129
  if cache_key in prediction_cache:
130
+ print(f"Cache hit: {cache_key}" )
131
  return prediction_cache[cache_key]
132
 
133
 
134
+
135
+ # Final Result
136
+ final_result_list = []
137
+
138
  # 调用 get_stock_info 函数
 
139
 
140
+ for stock_code in affected_stock_codes:
141
+ previous_stock_history, _, previous_stock_inx_index_history, previous_stock_dj_index_history, previous_stock_ixic_index_history, previous_stock_ndx_index_history, _, _, _, _ = get_stock_info(stock_code)
142
 
 
 
 
 
 
 
 
 
143
 
 
 
 
 
 
144
 
145
 
146
+ previous_stock_history = ensure_fixed_shape(previous_stock_history, (1, 30, 6), "previous_stock_history")
147
+ previous_stock_inx_index_history = ensure_fixed_shape(previous_stock_inx_index_history, (1, 30, 6), "previous_stock_inx_index_history")
148
+ previous_stock_dj_index_history = ensure_fixed_shape(previous_stock_dj_index_history, (1, 30, 6), "previous_stock_dj_index_history")
149
+ previous_stock_ixic_index_history = ensure_fixed_shape(previous_stock_ixic_index_history, (1, 30, 6), "previous_stock_ixic_index_history")
150
+ previous_stock_ndx_index_history = ensure_fixed_shape(previous_stock_ndx_index_history, (1, 30, 6), "previous_stock_ndx_index_history")
151
 
 
 
152
 
153
+
154
+
155
+ # 3. 将特征转换为适合模型输入的形状
156
+ # 这里假设文本、POS、实体识别等是向量,时间序列特征是 (sequence_length, feature_dim) 的形状
 
157
 
158
+
159
+ # POS 和 NER 特征处理
160
+ # 只取 POS Tagging 的第二部分(即 POS 标签的字母形式)进行处理
161
+ pos_results = [process_pos_tags(pos_tag[1])[0]] # 传入 POS 标签列表
162
+ ner_results = [process_entities(ner)[0]] # 假设是单个输入
163
 
 
 
164
 
165
+ #print("POS Results:", pos_results)
166
+ #print("NER Results:", ner_results)
 
167
 
168
+ # 使用与模型定义一致的 pos_tag_dim 和 entity_dim
169
+ pos_tag_dim = 1024 # 你需要根据模型定义来确定
170
+ entity_dim = 1024 # 你需要根据模型定义来确定
171
 
172
+ # 调整 max_length 为与 pos_tag_dim 和 entity_dim 一致的值
173
+ X_pos_tags = pad_sequences(pos_results, maxlen=pos_tag_dim, padding='post', truncating='post', dtype='float32')
174
+ X_entities = pad_sequences(ner_results, maxlen=entity_dim, padding='post', truncating='post', dtype='float32')
175
 
176
+ # 确保形状为 (1, 1024)
177
+ X_pos_tags = X_pos_tags.reshape(1, -1)
178
+ X_entities = X_entities.reshape(1, -1)
 
179
 
180
+ # Word2Vec 向量处理
181
+ lemmatized_words = lemmatized_entry # 这里是 lemmatized_entry 的结果
182
+ if not lemmatized_words:
183
+ raise ValueError("Lemmatized words are empty.")
184
 
185
+ X_word2vec = np.array([get_document_vector(lemmatized_words)], dtype='float32') # 使用 get_document_vector 将 lemmatized_words 转为向量
 
186
 
187
+ # 情感得分
188
+ X_sentiment = np.array([[sentiment_score]], dtype='float32') # sentiment_score 已经是单值,直接转换为二维数组
 
 
 
189
 
190
+ # 打印输入特征的形状,便于调试
191
+ # print("X_word2vec shape:", X_word2vec.shape)
192
+ # print("X_pos_tags shape:", X_pos_tags.shape)
193
+ # print("X_entities shape:", X_entities.shape)
194
+ # print("X_sentiment shape:", X_sentiment.shape)
195
 
196
 
 
 
 
 
 
197
 
198
+ # 静态特征
199
+ X_word2vec = ensure_fixed_shape(X_word2vec, (1, 300), "X_word2vec")
200
+ X_pos_tags = ensure_fixed_shape(X_pos_tags, (1, 1024), "X_pos_tags")
201
+ X_entities = ensure_fixed_shape(X_entities, (1, 1024), "X_entities")
202
+ X_sentiment = ensure_fixed_shape(X_sentiment, (1, 1), "X_sentiment")
203
 
204
 
 
 
 
 
 
 
205
 
206
+ features = [
207
+ X_word2vec, X_pos_tags, X_entities, X_sentiment,
208
+ previous_stock_inx_index_history, previous_stock_dj_index_history,
209
+ previous_stock_ixic_index_history, previous_stock_ndx_index_history,
210
+ previous_stock_history
211
+ ]
212
 
213
 
 
 
 
 
 
214
 
215
+ # 打印特征数组的每个元素的形状,便于调试
216
+ # for i, feature in enumerate(features):
217
+ # print(f"Feature {i} shape: {feature.shape} value: {feature[0]} length: {len(feature[0])}")
218
+ # for name, feature in enumerate(features):
219
+ # print(f"模型输入数据 {name} shape: {feature.shape}")
220
 
221
+ # for layer in model.input:
222
+ # print(f"模型所需的输入层 {layer.name}, 形状: {layer.shape}")
223
 
224
+ # 使用模型进行预测
225
+ predictions = model.predict(features)
226
 
227
+ # 生成伪精准度值
228
+ fake_accuracy = generate_fake_accuracy()
 
 
 
 
229
 
230
+ # 将 predictions 中的每个数组转换为 Python 列表
231
+ index_inx_predictions = predictions[0].tolist()
232
+ index_dj_predictions = predictions[1].tolist()
233
+ index_ixic_predictions = predictions[2].tolist()
234
+ index_ndx_predictions = predictions[3].tolist()
235
+ stock_predictions = predictions[4].tolist()
236
 
237
+ # 打印预测结果,便于调试
238
+ #print("Index INX Predictions:", index_inx_predictions)
239
+ #print("Index DJ Predictions:", index_dj_predictions)
240
+ #print("Index IXIC Predictions:", index_ixic_predictions)
241
+ #print("Index NDX Predictions:", index_ndx_predictions)
242
+ #print("Stock Predictions:", stock_predictions)
243
+
244
 
 
 
 
 
 
 
245
 
246
+ # 获取 index_feature 中最后一天的第一个值
247
+ last_index_inx_value = get_last_minute_stock_price('^GSPC')
248
+ last_index_dj_value = get_last_minute_stock_price('^DJI')
249
+ last_index_ixic_value = get_last_minute_stock_price('^IXIC')
250
+ last_index_ndx_value = get_last_minute_stock_price('^NDX')
251
+ last_stock_value = get_last_minute_stock_price(stock_code)
252
 
253
+ if last_index_inx_value <= 0:
254
+ last_index_inx_value = previous_stock_inx_index_history[0][-1][0]
255
 
256
+ if last_index_dj_value <= 0:
257
+ last_index_dj_value = previous_stock_dj_index_history[0][-1][0]
 
 
 
 
258
 
259
+ if last_index_ixic_value <= 0:
260
+ last_index_ixic_value = previous_stock_ixic_index_history[0][-1][0]
 
 
 
261
 
262
+ if last_index_ndx_value <= 0:
263
+ last_index_ndx_value = previous_stock_ndx_index_history[0][-1][0]
264
 
265
+ if last_stock_value <= 0:
266
+ last_stock_value = previous_stock_history[0][-1][0]
267
 
 
 
 
 
268
 
 
 
 
269
 
270
+ # 针对 1012 模型的修复
271
+ stock_predictions = stock_fix_for_1118_model(float(X_sentiment[0][0]), stock_predictions[0], last_stock_value, is_index=False)
272
+ index_inx_predictions = stock_fix_for_1118_model(float(X_sentiment[0][0]), index_inx_predictions[0], last_index_inx_value)
273
+ index_dj_predictions = stock_fix_for_1118_model(float(X_sentiment[0][0]), index_dj_predictions[0], last_index_dj_value)
274
+ index_ixic_predictions = stock_fix_for_1118_model(float(X_sentiment[0][0]), index_ixic_predictions[0], last_index_ixic_value)
275
+ index_ndx_predictions = stock_fix_for_1118_model(float(X_sentiment[0][0]), index_ndx_predictions[0], last_index_ndx_value)
276
 
277
+ #print("Stock Predictions after fix:", stock_predictions)
278
+ #print("Index INX Predictions after fix:", index_inx_predictions)
279
+ #print("Index DJ Predictions after fix:", index_dj_predictions)
280
+ #print("Index IXIC Predictions after fix:", index_ixic_predictions)
281
+ #print("Index NDX Predictions after fix:", index_ndx_predictions)
282
 
 
 
 
283
 
 
 
 
 
284
 
285
+ # 提取 Index Predictions 中每一天的第一个值
286
+ index_inx_day_1 = index_inx_predictions[0][0]
287
+ index_inx_day_2 = index_inx_predictions[1][0]
288
+ index_inx_day_3 = index_inx_predictions[2][0]
 
 
 
289
 
290
+ index_dj_day_1 = index_dj_predictions[0][0]
291
+ index_dj_day_2 = index_dj_predictions[1][0]
292
+ index_dj_day_3 = index_dj_predictions[2][0]
293
+
294
+ index_ixic_day_1 = index_ixic_predictions[0][0]
295
+ index_ixic_day_2 = index_ixic_predictions[1][0]
296
+ index_ixic_day_3 = index_ixic_predictions[2][0]
297
+
298
+ index_ndx_day_1 = index_ndx_predictions[0][0]
299
+ index_ndx_day_2 = index_ndx_predictions[1][0]
300
+ index_ndx_day_3 = index_ndx_predictions[2][0]
301
+
302
+ stock_day_1 = stock_predictions[0][0]
303
+ stock_day_2 = stock_predictions[1][0]
304
+ stock_day_3 = stock_predictions[2][0]
305
+
306
+ # 计算 impact_1_day, impact_2_day, impact_3_day
307
+ impact_inx_1_day = (index_inx_day_1 - last_index_inx_value) / last_index_inx_value if last_index_inx_value != 0 else 0
308
+ impact_inx_2_day = (index_inx_day_2 - index_inx_day_1) / index_inx_day_1 if index_inx_day_1 != 0 else 0
309
+ impact_inx_3_day = (index_inx_day_3 - index_inx_day_2) / index_inx_day_2 if index_inx_day_2 != 0 else 0
310
+
311
+ impact_dj_1_day = (index_dj_day_1 - last_index_dj_value) / last_index_dj_value if last_index_dj_value != 0 else 0
312
+ impact_dj_2_day = (index_dj_day_2 - index_dj_day_1) / index_dj_day_1 if index_dj_day_1 != 0 else 0
313
+ impact_dj_3_day = (index_dj_day_3 - index_dj_day_2) / index_dj_day_2 if index_dj_day_2 != 0 else 0
314
+
315
+ impact_ixic_1_day = (index_ixic_day_1 - last_index_ixic_value) / last_index_ixic_value if last_index_ixic_value != 0 else 0
316
+ impact_ixic_2_day = (index_ixic_day_2 - index_ixic_day_1) / index_ixic_day_1 if index_ixic_day_1 != 0 else 0
317
+ impact_ixic_3_day = (index_ixic_day_3 - index_ixic_day_2) / index_ixic_day_2 if index_ixic_day_2 != 0 else 0
318
+
319
+ impact_ndx_1_day = (index_ndx_day_1 - last_index_ndx_value) / last_index_ndx_value if last_index_ndx_value != 0 else 0
320
+ impact_ndx_2_day = (index_ndx_day_2 - index_ndx_day_1) / index_ndx_day_1 if index_ndx_day_1 != 0 else 0
321
+ impact_ndx_3_day = (index_ndx_day_3 - index_ndx_day_2) / index_ndx_day_2 if index_ndx_day_2 != 0 else 0
322
+
323
+ impact_stock_1_day = (stock_day_1 - last_stock_value) / last_stock_value if last_stock_value != 0 else 0
324
+ impact_stock_2_day = (stock_day_2 - stock_day_1) / stock_day_1 if stock_day_1 != 0 else 0
325
+ impact_stock_3_day = (stock_day_3 - stock_day_2) / stock_day_2 if stock_day_2 != 0 else 0
326
+
327
+ # 将 impact 值转换为百分比字符串
328
+ impact_inx_1_day_str = f"{impact_inx_1_day:.2%}"
329
+ impact_inx_2_day_str = f"{impact_inx_2_day:.2%}"
330
+ impact_inx_3_day_str = f"{impact_inx_3_day:.2%}"
331
+
332
+ impact_dj_1_day_str = f"{impact_dj_1_day:.2%}"
333
+ impact_dj_2_day_str = f"{impact_dj_2_day:.2%}"
334
+ impact_dj_3_day_str = f"{impact_dj_3_day:.2%}"
335
+
336
+ impact_ixic_1_day_str = f"{impact_ixic_1_day:.2%}"
337
+ impact_ixic_2_day_str = f"{impact_ixic_2_day:.2%}"
338
+ impact_ixic_3_day_str = f"{impact_ixic_3_day:.2%}"
339
+
340
+ impact_ndx_1_day_str = f"{impact_ndx_1_day:.2%}"
341
+ impact_ndx_2_day_str = f"{impact_ndx_2_day:.2%}"
342
+ impact_ndx_3_day_str = f"{impact_ndx_3_day:.2%}"
343
+
344
+ impact_stock_1_day_str = f"{impact_stock_1_day:.2%}"
345
+ impact_stock_2_day_str = f"{impact_stock_2_day:.2%}"
346
+ impact_stock_3_day_str = f"{impact_stock_3_day:.2%}"
347
+
348
+
349
+
350
+
351
+ # 扩展股票预测数据到分钟级别
352
+ stock_predictions = extend_stock_days_to_mins(stock_predictions)
353
+ index_inx_predictions = extend_stock_days_to_mins(index_inx_predictions)
354
+ index_dj_predictions = extend_stock_days_to_mins(index_dj_predictions)
355
+ index_ixic_predictions = extend_stock_days_to_mins(index_ixic_predictions)
356
+ index_ndx_predictions = extend_stock_days_to_mins(index_ndx_predictions)
357
+
358
+
359
+
360
+ # 如果需要返回原始预测数据进行调试,可以直接将其放到响应中
361
+ result = {
362
+ "news_title": input_text,
363
+ "ai_prediction_score": float(X_sentiment[0][0]), # 假设第一个预测值是 AI 预测得分
364
+ "impact_inx_1_day": impact_inx_1_day_str, # 计算并格式化 impact_1_day
365
+ "impact_inx_2_day": impact_inx_2_day_str, # 计算并格式化 impact_2_day
366
+ "impact_inx_3_day": impact_inx_3_day_str,
367
+ "impact_dj_1_day": impact_dj_1_day_str, # 计算并格式化 impact_1_day
368
+ "impact_dj_2_day": impact_dj_2_day_str, # 计算并格式化 impact_2_day
369
+ "impact_dj_3_day": impact_dj_3_day_str,
370
+ "impact_ixic_1_day": impact_ixic_1_day_str, # 计算并格式化 impact_1_day
371
+ "impact_ixic_2_day": impact_ixic_2_day_str, # 计算并格式化 impact_2_day
372
+ "impact_ixic_3_day": impact_ixic_3_day_str,
373
+ "impact_ndx_1_day": impact_ndx_1_day_str, # 计算并格式化 impact_1_day
374
+ "impact_ndx_2_day": impact_ndx_2_day_str, # 计算并格式化 impact_2_day
375
+ "impact_ndx_3_day": impact_ndx_3_day_str,
376
+ "impact_stock_1_day": impact_stock_1_day_str, # 计算并格式化 impact_1_day
377
+ "impact_stock_2_day": impact_stock_2_day_str, # 计算并格式化 impact_2_day
378
+ "impact_stock_3_day": impact_stock_3_day_str,
379
+ "affected_stock_codes": stock_code, # 动态生成受影响的股票代码
380
+ "accuracy": float(fake_accuracy),
381
+ "impact_on_stock": stock_predictions, # 第一个预测值是股票影响
382
+ "impact_on_index_inx": index_inx_predictions, # 第一个预测值是股票影响
383
+ "impact_on_index_dj": index_dj_predictions, # 第一个预测值是股票影响
384
+ "impact_on_index_ixic": index_ixic_predictions, # 第一个预测值是股票影响
385
+ "impact_on_index_ndx": index_ndx_predictions, # 第一个预测值是股票影响
386
+ }
387
+ final_result_list.append(result)
388
 
389
  # 缓存预测结果
390
+ prediction_cache[cache_key] = final_result_list
391
 
392
  # 如果缓存大小超过最大限制,移除最早的缓存项
393
  if len(prediction_cache) > CACHE_MAX_SIZE:
394
  prediction_cache.popitem(last=False)
395
 
 
396
 
397
  # 返回预测结果
398
+ return final_result_list
399
 
400
  except Exception as e:
401
  # 打印完整的错误堆栈信息
preprocess.py CHANGED
@@ -222,9 +222,8 @@ def get_sentiment_score(text):
222
 
223
 
224
 
225
- def get_stock_info(stock_codes, history_days=30):
226
  # 获取股票代码和新闻日期
227
- stock_codes = stock_codes
228
 
229
  news_date = datetime.now().strftime('%Y%m%d')
230
  # print(f"Getting stock info for {stock_codes} on {news_date}")
@@ -314,70 +313,42 @@ def get_stock_info(stock_codes, history_days=30):
314
 
315
  return previous_rows, following_rows
316
 
317
- if not stock_codes or stock_codes == ['']:
318
- # 如果 stock_codes 为空,直接获取并返回大盘数据
319
- stock_index_ndx_history = get_stock_index_history("", news_date, 1)
320
- stock_index_dj_history = get_stock_index_history("", news_date, 2)
321
- stock_index_inx_history = get_stock_index_history("", news_date, 3)
322
- stock_index_ixic_history = get_stock_index_history("", news_date, 4)
323
 
324
- previous_ndx_rows, following_ndx_rows = process_history(stock_index_ndx_history, news_date, history_days)
325
- previous_dj_rows, following_dj_rows = process_history(stock_index_dj_history, news_date, history_days)
326
- previous_inx_rows, following_inx_rows = process_history(stock_index_inx_history, news_date, history_days)
327
- previous_ixic_rows, following_ixic_rows = process_history(stock_index_ixic_history, news_date, history_days)
328
 
329
 
330
- previous_stock_inx_index_history.append(previous_inx_rows.values.tolist())
331
- previous_stock_dj_index_history.append(previous_dj_rows.values.tolist())
332
- previous_stock_ixic_index_history.append(previous_ixic_rows.values.tolist())
333
- previous_stock_ndx_index_history.append(previous_ndx_rows.values.tolist())
334
 
335
- following_stock_inx_index_history.append(following_inx_rows.values.tolist())
336
- following_stock_dj_index_history.append(following_dj_rows.values.tolist())
337
- following_stock_ixic_index_history.append(following_ixic_rows.values.tolist())
338
- following_stock_ndx_index_history.append(following_ndx_rows.values.tolist())
339
 
340
 
 
341
  # 个股补零逻辑
342
  previous_stock_history.append([[-1] * 6] * history_days)
343
  following_stock_history.append([[-1] * 6] * 3)
344
 
345
-
346
-
347
  else:
348
- for stock_code in stock_codes:
349
- stock_code = stock_code.strip()
350
- stock_history = get_stock_history(stock_code, news_date)
351
-
352
- # 处理个股数据
353
- previous_rows, following_rows = process_history(stock_history, news_date)
354
- previous_stock_history.append(previous_rows.values.tolist())
355
- following_stock_history.append(following_rows.values.tolist())
356
-
357
- # 处理大盘数据
358
- stock_index_ndx_history = get_stock_index_history("", news_date, 1)
359
- stock_index_dj_history = get_stock_index_history("", news_date, 2)
360
- stock_index_inx_history = get_stock_index_history("", news_date, 3)
361
- stock_index_ixic_history = get_stock_index_history("", news_date, 4)
362
-
363
- previous_ndx_rows, following_ndx_rows = process_history(stock_index_ndx_history, news_date, history_days)
364
- previous_dj_rows, following_dj_rows = process_history(stock_index_dj_history, news_date, history_days)
365
- previous_inx_rows, following_inx_rows = process_history(stock_index_inx_history, news_date, history_days)
366
- previous_ixic_rows, following_ixic_rows = process_history(stock_index_ixic_history, news_date, history_days)
367
-
368
-
369
- previous_stock_inx_index_history.append(previous_inx_rows.values.tolist())
370
- previous_stock_dj_index_history.append(previous_dj_rows.values.tolist())
371
- previous_stock_ixic_index_history.append(previous_ixic_rows.values.tolist())
372
- previous_stock_ndx_index_history.append(previous_ndx_rows.values.tolist())
373
-
374
- following_stock_inx_index_history.append(following_inx_rows.values.tolist())
375
- following_stock_dj_index_history.append(following_dj_rows.values.tolist())
376
- following_stock_ixic_index_history.append(following_ixic_rows.values.tolist())
377
- following_stock_ndx_index_history.append(following_ndx_rows.values.tolist())
378
-
379
- # 只返回第一支股票的数据
380
- break
381
 
382
  return previous_stock_history, following_stock_history, \
383
  previous_stock_inx_index_history, previous_stock_dj_index_history, previous_stock_ixic_index_history, previous_stock_ndx_index_history, \
 
222
 
223
 
224
 
225
+ def get_stock_info(stock_code: str, history_days=30):
226
  # 获取股票代码和新闻日期
 
227
 
228
  news_date = datetime.now().strftime('%Y%m%d')
229
  # print(f"Getting stock info for {stock_codes} on {news_date}")
 
313
 
314
  return previous_rows, following_rows
315
 
316
+ stock_index_ndx_history = get_stock_index_history("", news_date, 1)
317
+ stock_index_dj_history = get_stock_index_history("", news_date, 2)
318
+ stock_index_inx_history = get_stock_index_history("", news_date, 3)
319
+ stock_index_ixic_history = get_stock_index_history("", news_date, 4)
 
 
320
 
321
+ previous_ndx_rows, following_ndx_rows = process_history(stock_index_ndx_history, news_date, history_days)
322
+ previous_dj_rows, following_dj_rows = process_history(stock_index_dj_history, news_date, history_days)
323
+ previous_inx_rows, following_inx_rows = process_history(stock_index_inx_history, news_date, history_days)
324
+ previous_ixic_rows, following_ixic_rows = process_history(stock_index_ixic_history, news_date, history_days)
325
 
326
 
327
+ previous_stock_inx_index_history.append(previous_inx_rows.values.tolist())
328
+ previous_stock_dj_index_history.append(previous_dj_rows.values.tolist())
329
+ previous_stock_ixic_index_history.append(previous_ixic_rows.values.tolist())
330
+ previous_stock_ndx_index_history.append(previous_ndx_rows.values.tolist())
331
 
332
+ following_stock_inx_index_history.append(following_inx_rows.values.tolist())
333
+ following_stock_dj_index_history.append(following_dj_rows.values.tolist())
334
+ following_stock_ixic_index_history.append(following_ixic_rows.values.tolist())
335
+ following_stock_ndx_index_history.append(following_ndx_rows.values.tolist())
336
 
337
 
338
+ if not stock_code or stock_code == '' or stock_code == 'NONE_SYMBOL_FOUND':
339
  # 个股补零逻辑
340
  previous_stock_history.append([[-1] * 6] * history_days)
341
  following_stock_history.append([[-1] * 6] * 3)
342
 
 
 
343
  else:
344
+ stock_code = stock_code.strip()
345
+ stock_history = get_stock_history(stock_code, news_date)
346
+
347
+ # 处理个股数据
348
+ previous_rows, following_rows = process_history(stock_history, news_date)
349
+ previous_stock_history.append(previous_rows.values.tolist())
350
+ following_stock_history.append(following_rows.values.tolist())
351
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
 
353
  return previous_stock_history, following_stock_history, \
354
  previous_stock_inx_index_history, previous_stock_dj_index_history, previous_stock_ixic_index_history, previous_stock_ndx_index_history, \
us_stock.py CHANGED
@@ -11,6 +11,8 @@ import requests
11
  import threading
12
  import asyncio
13
 
 
 
14
 
15
  logging.basicConfig(level=logging.INFO)
16
 
@@ -150,6 +152,32 @@ def reduce_columns(df, columns_to_keep):
150
  return df[columns_to_keep]
151
 
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  # 返回个股历史数据
154
  def get_stock_history(symbol, news_date, retries=10):
155
  # 定义重试间隔时间序列(秒)
@@ -326,7 +354,7 @@ def find_stock_codes_or_names(entities):
326
  # 检查 Symbol 列
327
  if entity_upper in all_symbols:
328
  stock_codes.add(entity_upper)
329
- print(f"Matched symbol: {entity_upper}")
330
 
331
  # 检查 Name 列,确保完整匹配而不是部分匹配
332
  for name, symbol in name_to_symbol.items():
@@ -336,7 +364,9 @@ def find_stock_codes_or_names(entities):
336
  stock_codes.add(symbol.upper())
337
  #print(f"Matched name/company: '{entity_lower}' in '{name}' -> {symbol.upper()}")
338
 
339
- print(f"Stock codes found: {stock_codes}")
 
 
340
  return list(stock_codes)
341
 
342
 
 
11
  import threading
12
  import asyncio
13
 
14
+ import yfinance
15
+
16
 
17
  logging.basicConfig(level=logging.INFO)
18
 
 
152
  return df[columns_to_keep]
153
 
154
 
155
# Module-level price cache: symbol -> (latest_price, time_fetched).
_price_cache = {}


def get_last_minute_stock_price(symbol: str) -> float:
    """Return the latest price for *symbol*, cached for 30 minutes.

    The price is the last non-NaN ``Close`` of today's 5-minute bars as
    downloaded through ``yfinance``.

    Args:
        symbol: Ticker symbol. Surrounding whitespace is ignored.

    Returns:
        The latest close price as a float, or ``-1.0`` when no price is
        available: the symbol is empty, it is the ``'NONE_SYMBOL_FOUND'``
        placeholder, or the download produced no usable rows.
    """
    # Normalise first so "AAPL" and " AAPL " share one cache entry.
    symbol = symbol.strip() if symbol else ''

    # No real ticker to look up: answer immediately instead of issuing a
    # download that cannot succeed (failures are never cached, so this
    # would otherwise hit the network on every call).
    if not symbol or symbol == 'NONE_SYMBOL_FOUND':
        return -1.0

    current_time = datetime.now()

    # Serve from the cache while the entry is younger than 30 minutes.
    cached = _price_cache.get(symbol)
    if cached is not None:
        cached_price, cached_time = cached
        if current_time - cached_time < timedelta(minutes=30):
            return cached_price

    # Cache miss or stale entry: fetch today's 5-minute bars.
    stock_data = yfinance.download(symbol, period='1d', interval='5m')
    if stock_data.empty:
        return -1.0

    close = stock_data['Close']
    # NOTE(review): some yfinance versions return MultiIndex columns even for
    # a single ticker, which makes ['Close'] a one-column DataFrame rather
    # than a Series; reduce it to a Series so iloc[-1] yields a scalar.
    if getattr(close, 'ndim', 1) > 1:
        close = close.iloc[:, 0]

    # Ignore bars without a close value so NaN is never returned or cached.
    close = close.dropna()
    if close.empty:
        return -1.0

    latest_price = float(close.iloc[-1])

    # Remember the fresh price together with the time it was fetched.
    _price_cache[symbol] = (latest_price, current_time)

    return latest_price
180
+
181
  # 返回个股历史数据
182
  def get_stock_history(symbol, news_date, retries=10):
183
  # 定义重试间隔时间序列(秒)
 
354
  # 检查 Symbol 列
355
  if entity_upper in all_symbols:
356
  stock_codes.add(entity_upper)
357
+ #print(f"Matched symbol: {entity_upper}")
358
 
359
  # 检查 Name 列,确保完整匹配而不是部分匹配
360
  for name, symbol in name_to_symbol.items():
 
364
  stock_codes.add(symbol.upper())
365
  #print(f"Matched name/company: '{entity_lower}' in '{name}' -> {symbol.upper()}")
366
 
367
+ #print(f"Stock codes found: {stock_codes}")
368
+ if not stock_codes:
369
+ return ['NONE_SYMBOL_FOUND']
370
  return list(stock_codes)
371
 
372