kristada673 committed
Commit de6e775 · 1 Parent(s): c31337d

Upload 19 files

finnlp/data_processors/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (172 Bytes).
 
finnlp/data_processors/__pycache__/_base.cpython-310.pyc ADDED
Binary file (12.6 kB).
 
finnlp/data_processors/__pycache__/yahoofinance.cpython-310.pyc ADDED
Binary file (4.66 kB).
 
finnlp/data_processors/_base.py ADDED
@@ -0,0 +1,637 @@
+ import copy
+ import os
+ import urllib
+ import zipfile
+ from datetime import *
+ from pathlib import Path
+ from typing import List
+
+ import numpy as np
+ import pandas as pd
+ import stockstats
+ import talib
+
+ from finnlp.utils.config import BINANCE_BASE_URL
+ from finnlp.utils.config import TIME_ZONE_BERLIN
+ from finnlp.utils.config import TIME_ZONE_JAKARTA
+ from finnlp.utils.config import TIME_ZONE_PARIS
+ from finnlp.utils.config import TIME_ZONE_SELFDEFINED
+ from finnlp.utils.config import TIME_ZONE_SHANGHAI
+ from finnlp.utils.config import TIME_ZONE_USEASTERN
+ from finnlp.utils.config import USE_TIME_ZONE_SELFDEFINED
+ from finnlp.utils.config_tickers import CAC_40_TICKER
+ from finnlp.utils.config_tickers import CSI_300_TICKER
+ from finnlp.utils.config_tickers import DAX_30_TICKER
+ from finnlp.utils.config_tickers import DOW_30_TICKER
+ from finnlp.utils.config_tickers import HSI_50_TICKER
+ from finnlp.utils.config_tickers import LQ45_TICKER
+ from finnlp.utils.config_tickers import MDAX_50_TICKER
+ from finnlp.utils.config_tickers import NAS_100_TICKER
+ from finnlp.utils.config_tickers import SDAX_50_TICKER
+ from finnlp.utils.config_tickers import SP_500_TICKER
+ from finnlp.utils.config_tickers import SSE_50_TICKER
+ from finnlp.utils.config_tickers import TECDAX_TICKER
+
+
+ class _Base:
+     def __init__(
+         self,
+         data_source: str,
+         start_date: str,
+         end_date: str,
+         time_interval: str,
+         **kwargs,
+     ):
+         self.data_source: str = data_source
+         self.start_date: str = start_date
+         self.end_date: str = end_date
+         self.time_interval: str = time_interval  # standard time_interval
+         # transferred_time_interval will be supported in the future.
+         # self.nonstandard_time_interval: str = self.calc_nonstandard_time_interval()  # transferred time_interval of this processor
+         self.time_zone: str = ""
+         self.dataframe: pd.DataFrame = pd.DataFrame()
+         self.dictnumpy: dict = (
+             {}
+         )  # e.g., self.dictnumpy["open"] = np.array([1, 2, 3]), self.dictnumpy["close"] = np.array([1, 2, 3])
+
+     def download_data(self, ticker_list: List[str]):
+         pass
+
+     def clean_data(self):
+         if "date" in self.dataframe.columns.values.tolist():
+             self.dataframe.rename(columns={"date": "time"}, inplace=True)
+         if "datetime" in self.dataframe.columns.values.tolist():
+             self.dataframe.rename(columns={"datetime": "time"}, inplace=True)
+         if self.data_source == "ccxt":
+             self.dataframe.rename(columns={"index": "time"}, inplace=True)
+
+         if self.data_source == "ricequant":
+             """RiceQuant data is already cleaned, we only need to transform the data format here.
+             No need to fill NaN data."""
+             self.dataframe.rename(columns={"order_book_id": "tic"}, inplace=True)
+             # the raw df uses a multi-index (tic, time); reset it to a single index (time)
+             self.dataframe.reset_index(level=[0, 1], inplace=True)
+             # check if there are NaN values
+             assert not self.dataframe.isnull().values.any()
+         elif self.data_source == "baostock":
+             self.dataframe.rename(columns={"code": "tic"}, inplace=True)
+
+         self.dataframe.dropna(inplace=True)
+         # adjusted_close: adjusted close price
+         if "adjusted_close" not in self.dataframe.columns.values.tolist():
+             self.dataframe["adjusted_close"] = self.dataframe["close"]
+         self.dataframe.sort_values(by=["time", "tic"], inplace=True)
+         self.dataframe = self.dataframe[
+             [
+                 "tic",
+                 "time",
+                 "open",
+                 "high",
+                 "low",
+                 "close",
+                 "adjusted_close",
+                 "volume",
+             ]
+         ]
+
+     def fillna(self):
+         df = self.dataframe
+
+         dfcode = pd.DataFrame(columns=["tic"])
+         dfdate = pd.DataFrame(columns=["time"])
+
+         dfcode.tic = df.tic.unique()
+         dfdate.time = df.time.unique()
+         dfdate.sort_values(by="time", ascending=False, ignore_index=True, inplace=True)
+
+         # old pandas versions may not support pd.merge(how="cross")
+         try:
+             df1 = pd.merge(dfcode, dfdate, how="cross")
+         except Exception:
+             print("Please wait for a few seconds...")
+             df1 = pd.DataFrame(columns=["tic", "time"])
+             for i in range(dfcode.shape[0]):
+                 for j in range(dfdate.shape[0]):
+                     df1 = df1.append(
+                         pd.DataFrame(
+                             data={
+                                 "tic": dfcode.iat[i, 0],
+                                 "time": dfdate.iat[j, 0],
+                             },
+                             index=[(i + 1) * (j + 1) - 1],
+                         )
+                     )
+
+         df = pd.merge(df1, df, how="left", on=["tic", "time"])
+
+         # back-fill missing data, then forward-fill
+         df_new = pd.DataFrame(columns=df.columns)
+         for i in df.tic.unique():
+             df_tmp = df[df.tic == i].fillna(method="bfill").fillna(method="ffill")
+             df_new = pd.concat([df_new, df_tmp], ignore_index=True)
+
+         df_new = df_new.fillna(0)
+
+         # reshape dataframe
+         df_new = df_new.sort_values(by=["time", "tic"]).reset_index(drop=True)
+
+         print("Shape of DataFrame: ", df_new.shape)
+
+         self.dataframe = df_new
+
+     def get_trading_days(self, start: str, end: str) -> List[str]:
+         if self.data_source in [
+             "binance",
+             "ccxt",
+             "quantconnect",
+             "ricequant",
+             "tushare",
+         ]:
+             print(
+                 f"Calculate get_trading_days not supported for {self.data_source} yet."
+             )
+             return None
+
+     # select_stockstats_talib: 0 (stockstats, default) or 1 (use talib). Users can choose the method.
+     # drop_na_timesteps: 0 (keep timesteps that contain NaN) or 1 (drop timesteps that contain NaN, default). Users can choose the method.
+     def add_technical_indicator(
+         self,
+         tech_indicator_list: List[str],
+         select_stockstats_talib: int = 0,
+         drop_na_timesteps: int = 1,
+     ):
+         """
+         Calculate technical indicators,
+         using the stockstats/talib package to add technical indicators.
+         :param data: (df) pandas dataframe
+         :return: (df) pandas dataframe
+         """
+         if "date" in self.dataframe.columns.values.tolist():
+             self.dataframe.rename(columns={"date": "time"}, inplace=True)
+
+         if self.data_source == "ccxt":
+             self.dataframe.rename(columns={"index": "time"}, inplace=True)
+
+         self.dataframe.reset_index(drop=False, inplace=True)
+         if "level_1" in self.dataframe.columns:
+             self.dataframe.drop(columns=["level_1"], inplace=True)
+         if "level_0" in self.dataframe.columns and "tic" not in self.dataframe.columns:
+             self.dataframe.rename(columns={"level_0": "tic"}, inplace=True)
+         assert select_stockstats_talib in {0, 1}
+         print("tech_indicator_list: ", tech_indicator_list)
+         if select_stockstats_talib == 0:  # use stockstats
+             stock = stockstats.StockDataFrame.retype(self.dataframe)
+             unique_ticker = stock.tic.unique()
+             for indicator in tech_indicator_list:
+                 print("indicator: ", indicator)
+                 indicator_df = pd.DataFrame()
+                 for i in range(len(unique_ticker)):
+                     try:
+                         temp_indicator = stock[stock.tic == unique_ticker[i]][indicator]
+                         temp_indicator = pd.DataFrame(temp_indicator)
+                         temp_indicator["tic"] = unique_ticker[i]
+                         temp_indicator["time"] = self.dataframe[
+                             self.dataframe.tic == unique_ticker[i]
+                         ]["time"].to_list()
+                         indicator_df = pd.concat(
+                             [indicator_df, temp_indicator],
+                             axis=0,
+                             join="outer",
+                             ignore_index=True,
+                         )
+                     except Exception as e:
+                         print(e)
+                 if not indicator_df.empty:
+                     self.dataframe = self.dataframe.merge(
+                         indicator_df[["tic", "time", indicator]],
+                         on=["tic", "time"],
+                         how="left",
+                     )
+         else:  # use talib
+             final_df = pd.DataFrame()
+             for i in self.dataframe.tic.unique():
+                 tic_df = self.dataframe[self.dataframe.tic == i]
+                 # assign indicator columns per ticker (column indexing, not .loc row labels)
+                 (
+                     tic_df["macd"],
+                     tic_df["macd_signal"],
+                     tic_df["macd_hist"],
+                 ) = talib.MACD(
+                     tic_df["close"],
+                     fastperiod=12,
+                     slowperiod=26,
+                     signalperiod=9,
+                 )
+                 tic_df["rsi"] = talib.RSI(tic_df["close"], timeperiod=14)
+                 tic_df["cci"] = talib.CCI(
+                     tic_df["high"],
+                     tic_df["low"],
+                     tic_df["close"],
+                     timeperiod=14,
+                 )
+                 tic_df["dx"] = talib.DX(
+                     tic_df["high"],
+                     tic_df["low"],
+                     tic_df["close"],
+                     timeperiod=14,
+                 )
+                 final_df = pd.concat([final_df, tic_df], axis=0, join="outer")
+             self.dataframe = final_df
+
+         self.dataframe.sort_values(by=["time", "tic"], inplace=True)
+         if drop_na_timesteps:
+             time_to_drop = self.dataframe[
+                 self.dataframe.isna().any(axis=1)
+             ].time.unique()
+             self.dataframe = self.dataframe[~self.dataframe.time.isin(time_to_drop)]
+         print("Successfully added technical indicators")
+
+     def add_turbulence(self):
+         """
+         Add turbulence index from a precalculated dataframe.
+         :param data: (df) pandas dataframe
+         :return: (df) pandas dataframe
+         """
+         # df = data.copy()
+         # turbulence_index = self.calculate_turbulence(df)
+         # df = df.merge(turbulence_index, on="time")
+         # df = df.sort_values(["time", "tic"]).reset_index(drop=True)
+         # return df
+         if self.data_source in [
+             "binance",
+             "ccxt",
+             "iexcloud",
+             "joinquant",
+             "quantconnect",
+         ]:
+             print(
+                 f"Turbulence not supported for {self.data_source} yet. Return original DataFrame."
+             )
+         if self.data_source in [
+             "alpaca",
+             "ricequant",
+             "tushare",
+             "wrds",
+             "yahoofinance",
+         ]:
+             turbulence_index = self.calculate_turbulence()
+             self.dataframe = self.dataframe.merge(turbulence_index, on="time")
+             self.dataframe.sort_values(["time", "tic"], inplace=True)
+             self.dataframe.reset_index(drop=True, inplace=True)
+
+     def calculate_turbulence(self, time_period: int = 252) -> pd.DataFrame:
+         """Calculate turbulence index based on the Dow 30."""
+         # can add other market assets
+         df_price_pivot = self.dataframe.pivot(
+             index="time", columns="tic", values="close"
+         )
+         # use returns to calculate turbulence
+         df_price_pivot = df_price_pivot.pct_change()
+
+         unique_date = self.dataframe["time"].unique()
+         # start after a year
+         start = time_period
+         turbulence_index = [0] * start
+         # turbulence_index = [0]
+         count = 0
+         for i in range(start, len(unique_date)):
+             current_price = df_price_pivot[df_price_pivot.index == unique_date[i]]
+             # use a one-year rolling window to calculate covariance
+             hist_price = df_price_pivot[
+                 (df_price_pivot.index < unique_date[i])
+                 & (df_price_pivot.index >= unique_date[i - time_period])
+             ]
+             # Drop tickers which have more missing values than the "oldest" ticker
+             filtered_hist_price = hist_price.iloc[
+                 hist_price.isna().sum().min() :
+             ].dropna(axis=1)
+
+             cov_temp = filtered_hist_price.cov()
+             current_temp = current_price[list(filtered_hist_price)] - np.mean(
+                 filtered_hist_price, axis=0
+             )
+             # cov_temp = hist_price.cov()
+             # current_temp = (current_price - np.mean(hist_price, axis=0))
+
+             temp = current_temp.values.dot(np.linalg.pinv(cov_temp)).dot(
+                 current_temp.values.T
+             )
+             if temp > 0:
+                 count += 1
+                 # avoid large outliers while the calculation just begins: else turbulence_temp = 0
+                 turbulence_temp = temp[0][0] if count > 2 else 0
+             else:
+                 turbulence_temp = 0
+             turbulence_index.append(turbulence_temp)
+
+         turbulence_index = pd.DataFrame(
+             {"time": df_price_pivot.index, "turbulence": turbulence_index}
+         )
+         return turbulence_index
+
+     def add_vix(self):
+         """
+         Add VIX from processors.
+         :param data: (df) pandas dataframe
+         :return: (df) pandas dataframe
+         """
+         if self.data_source in [
+             "binance",
+             "ccxt",
+             "iexcloud",
+             "joinquant",
+             "quantconnect",
+             "ricequant",
+             "tushare",
+         ]:
+             print(
+                 f"VIX is not applicable for {self.data_source}. Return original DataFrame"
+             )
+             return None
+
+         # if self.data_source == 'yahoofinance':
+         #     df = data.copy()
+         #     df_vix = self.download_data(
+         #         start_date=df.time.min(),
+         #         end_date=df.time.max(),
+         #         ticker_list=["^VIX"],
+         #         time_interval=self.time_interval,
+         #     )
+         #     df_vix = self.clean_data(df_vix)
+         #     vix = df_vix[["time", "adjusted_close"]]
+         #     vix.columns = ["time", "vix"]
+         #
+         #     df = df.merge(vix, on="time")
+         #     df = df.sort_values(["time", "tic"]).reset_index(drop=True)
+         # elif self.data_source == 'alpaca':
+         #     vix_df = self.download_data(["VIXY"], self.start, self.end, self.time_interval)
+         #     cleaned_vix = self.clean_data(vix_df)
+         #     vix = cleaned_vix[["time", "close"]]
+         #     vix = vix.rename(columns={"close": "VIXY"})
+         #
+         #     df = data.copy()
+         #     df = df.merge(vix, on="time")
+         #     df = df.sort_values(["time", "tic"]).reset_index(drop=True)
+         # elif self.data_source == 'wrds':
+         #     vix_df = self.download_data(['vix'], self.start, self.end_date, self.time_interval)
+         #     cleaned_vix = self.clean_data(vix_df)
+         #     vix = cleaned_vix[['date', 'close']]
+         #
+         #     df = data.copy()
+         #     df = df.merge(vix, on="date")
+         #     df = df.sort_values(["date", "tic"]).reset_index(drop=True)
+
+         elif self.data_source == "yahoofinance":
+             ticker = "^VIX"
+         elif self.data_source == "alpaca":
+             ticker = "VIXY"
+         elif self.data_source == "wrds":
+             ticker = "vix"
+         else:
+             pass
+         df = self.dataframe.copy()
+         # self.download_data(self.start_date, self.end_date, self.time_interval)
+         self.download_data([ticker], save_path="./data/vix.csv")
+         self.clean_data()
+         cleaned_vix = self.dataframe
+         # .rename(columns={ticker: "vix"})
+         vix = cleaned_vix[["time", "close"]]
+         cleaned_vix = vix.rename(columns={"close": "vix"})
+
+         df = df.merge(cleaned_vix, on="time")
+         df = df.sort_values(["time", "tic"]).reset_index(drop=True)
+         self.dataframe = df
+
+     def df_to_array(self, tech_indicator_list: List[str], if_vix: bool):
+         unique_ticker = self.dataframe.tic.unique()
+         price_array = np.column_stack(
+             [self.dataframe[self.dataframe.tic == tic].close for tic in unique_ticker]
+         )
+         common_tech_indicator_list = [
+             i
+             for i in tech_indicator_list
+             if i in self.dataframe.columns.values.tolist()
+         ]
+         tech_array = np.hstack(
+             [
+                 self.dataframe.loc[
+                     (self.dataframe.tic == tic), common_tech_indicator_list
+                 ]
+                 for tic in unique_ticker
+             ]
+         )
+         if if_vix:
+             risk_array = np.column_stack(
+                 [self.dataframe[self.dataframe.tic == tic].vix for tic in unique_ticker]
+             )
+         else:
+             risk_array = (
+                 np.column_stack(
+                     [
+                         self.dataframe[self.dataframe.tic == tic].turbulence
+                         for tic in unique_ticker
+                     ]
+                 )
+                 if "turbulence" in self.dataframe.columns
+                 else None
+             )
+         print("Successfully transformed into array")
+         return price_array, tech_array, risk_array
+
+     # standard_time_interval  s: second, m: minute, h: hour, d: day, w: week, M: month, q: quarter, y: year
+     # output time_interval of the processor
+     def calc_nonstandard_time_interval(self) -> str:
+         if self.data_source == "alpaca":
+             pass
+         elif self.data_source == "baostock":
+             # nonstandard_time_interval: defaults to d (daily k-line); d = daily, w = weekly, m = monthly, 5 = 5-minute, 15 = 15-minute, 30 = 30-minute, 60 = 60-minute k-line data, case-insensitive; indices have no minute-level data; weekly bars are only available on the last trading day of the week, and monthly bars only on the last trading day of the month.
+             time_intervals = ["5m", "15m", "30m", "60m", "1d", "1w", "1M"]
+             assert self.time_interval in time_intervals, (
+                 "This time interval is not supported. Supported time intervals: "
+                 + ",".join(time_intervals)
+             )
+             if (
+                 "d" in self.time_interval
+                 or "w" in self.time_interval
+                 or "M" in self.time_interval
+             ):
+                 return self.time_interval[-1:].lower()
+             elif "m" in self.time_interval:
+                 return self.time_interval[:-1]
+         elif self.data_source == "binance":
+             # nonstandard_time_interval: 1m,3m,5m,15m,30m,1h,2h,4h,6h,8h,12h,1d,3d,1w,1M
+             time_intervals = [
+                 "1m",
+                 "3m",
+                 "5m",
+                 "15m",
+                 "30m",
+                 "1h",
+                 "2h",
+                 "4h",
+                 "6h",
+                 "8h",
+                 "12h",
+                 "1d",
+                 "3d",
+                 "1w",
+                 "1M",
+             ]
+             assert self.time_interval in time_intervals, (
+                 "This time interval is not supported. Supported time intervals: "
+                 + ",".join(time_intervals)
+             )
+             return self.time_interval
+         elif self.data_source == "ccxt":
+             pass
+         elif self.data_source == "iexcloud":
+             time_intervals = ["1d"]
+             assert self.time_interval in time_intervals, (
+                 "This time interval is not supported. Supported time intervals: "
+                 + ",".join(time_intervals)
+             )
+             return self.time_interval.upper()
+         elif self.data_source == "joinquant":
+             # '1m', '5m', '15m', '30m', '60m', '120m', '1d', '1w', '1M'
+             time_intervals = [
+                 "1m",
+                 "5m",
+                 "15m",
+                 "30m",
+                 "60m",
+                 "120m",
+                 "1d",
+                 "1w",
+                 "1M",
+             ]
+             assert self.time_interval in time_intervals, (
+                 "This time interval is not supported. Supported time intervals: "
+                 + ",".join(time_intervals)
+             )
+             return self.time_interval
+         elif self.data_source == "quantconnect":
+             pass
+         elif self.data_source == "ricequant":
+             # nonstandard_time_interval: 'd' - day, 'w' - week, 'm' - month, 'q' - quarter, 'y' - year
+             time_intervals = ["d", "w", "M", "q", "y"]
+             assert self.time_interval[-1] in time_intervals, (
+                 "This time interval is not supported. Supported time intervals: "
+                 + ",".join(time_intervals)
+             )
+             if "M" in self.time_interval:
+                 return self.time_interval.lower()
+             else:
+                 return self.time_interval
+         elif self.data_source == "tushare":
+             # Minute frequencies include 1-, 5-, 15-, 30-, and 60-minute data. Not supported currently.
+             # time_intervals = ["1m", "5m", "15m", "30m", "60m", "1d"]
+             time_intervals = ["1d"]
+             assert self.time_interval in time_intervals, (
+                 "This time interval is not supported. Supported time intervals: "
+                 + ",".join(time_intervals)
+             )
+             return self.time_interval
+         elif self.data_source == "wrds":
+             pass
+         elif self.data_source == "yahoofinance":
+             # nonstandard_time_interval: ["1m", "2m", "5m", "15m", "30m", "60m", "90m", "1h", "1d", "5d", "1wk", "1mo", "3mo"]
+             time_intervals = [
+                 "1m",
+                 "2m",
+                 "5m",
+                 "15m",
+                 "30m",
+                 "60m",
+                 "90m",
+                 "1h",
+                 "1d",
+                 "5d",
+                 "1w",
+                 "1M",
+                 "3M",
+             ]
+             assert self.time_interval in time_intervals, (
+                 "This time interval is not supported. Supported time intervals: "
+                 + ",".join(time_intervals)
+             )
+             if "w" in self.time_interval:
+                 return self.time_interval + "k"
+             elif "M" in self.time_interval:
+                 return self.time_interval[:-1] + "mo"
+             else:
+                 return self.time_interval
+         else:
+             raise ValueError(
+                 f"transfer_standard_time_interval is not supported for {self.data_source}"
+             )
+
+     # "600000.XSHG" -> "sh.600000"
+     # "000612.XSHE" -> "sz.000612"
+     def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str:
+         return ticker
+
+     def save_data(self, path):
+         if ".csv" in path:
+             path = path.split("/")
+             filename = path[-1]
+             path = "/".join(path[:-1] + [""])
+         else:
+             if path[-1] == "/":
+                 filename = "dataset.csv"
+             else:
+                 filename = "/dataset.csv"
+
+         os.makedirs(path, exist_ok=True)
+         self.dataframe.to_csv(path + filename, index=False)
+
+     def load_data(self, path):
+         assert ".csv" in path  # only the csv format is supported for now
+         self.dataframe = pd.read_csv(path)
+         columns = self.dataframe.columns
+         print(f"{path} loaded")
+         # # check loaded file
+         # assert "date" in columns or "time" in columns
+         # assert "close" in columns
+
+
+ def calc_time_zone(
+     ticker_list: List[str],
+     time_zone_selfdefined: str,
+     use_time_zone_selfdefined: int,
+ ) -> str:
+     assert isinstance(ticker_list, list)
+     ticker_list = ticker_list[0]
+     if use_time_zone_selfdefined == 1:
+         time_zone = time_zone_selfdefined
+     elif ticker_list in HSI_50_TICKER + SSE_50_TICKER + CSI_300_TICKER:
+         time_zone = TIME_ZONE_SHANGHAI
+     elif ticker_list in DOW_30_TICKER + NAS_100_TICKER + SP_500_TICKER:
+         time_zone = TIME_ZONE_USEASTERN
+     elif ticker_list in CAC_40_TICKER:
+         time_zone = TIME_ZONE_PARIS
+     elif ticker_list in DAX_30_TICKER + TECDAX_TICKER + MDAX_50_TICKER + SDAX_50_TICKER:
+         time_zone = TIME_ZONE_BERLIN
+     elif ticker_list in LQ45_TICKER:
+         time_zone = TIME_ZONE_JAKARTA
+     else:
+         # hack needed to have this working with the vix indicator
+         # fix: unable to set time_zone_selfdefined from the top-level data processor class
+         time_zone = TIME_ZONE_USEASTERN
+         # raise ValueError("Time zone is wrong.")
+     return time_zone
+
+
+ def check_date(d: str) -> bool:
+     assert (
+         len(d) == 10
+     ), "Please check the length of the date and use the correct format, like 2020-01-01."
+     indices = [0, 1, 2, 3, 5, 6, 8, 9]
+     correct = True
+     for i in indices:
+         if not d[i].isdigit():
+             correct = False
+             break
+     if not correct:
+         raise ValueError("Please use the correct date format, like 2020-01-01.")
+     return correct
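For orientation, here is a minimal usage sketch of how a _Base subclass is meant to be driven. This is hedged: the Yahoofinance class name and import path are assumed from the compiled yahoofinance.cpython-310.pyc added above, and the indicator names are illustrative stockstats columns. Note that calculate_turbulence computes, per timestep, the Mahalanobis distance of current returns from their one-year mean under the rolling covariance, so the risk column stays zero for the first time_period steps.

# Usage sketch, not part of the commit; assumes Yahoofinance implements download_data.
from finnlp.data_processors.yahoofinance import Yahoofinance

p = Yahoofinance("yahoofinance", "2020-01-01", "2021-01-01", "1d")
p.download_data(["AAPL", "MSFT"])              # fills p.dataframe with raw bars
p.clean_data()                                 # normalizes columns to tic/time/OHLCV
p.add_technical_indicator(["macd", "rsi_30"])  # stockstats path by default
p.add_vix()                                    # merges ^VIX closes in as a "vix" column
price_array, tech_array, risk_array = p.df_to_array(["macd", "rsi_30"], if_vix=True)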
finnlp/data_processors/akshare.py ADDED
@@ -0,0 +1,148 @@
+ import copy
+ import os
+ import time
+ import warnings
+
+ warnings.filterwarnings("ignore")
+ from typing import List
+
+ import pandas as pd
+ from tqdm import tqdm
+
+ import stockstats
+ import talib
+ from finnlp.data_processors._base import _Base
+
+ import akshare as ak  # pip install akshare
+
+
+ class Akshare(_Base):
+     def __init__(
+         self,
+         data_source: str,
+         start_date: str,
+         end_date: str,
+         time_interval: str,
+         **kwargs,
+     ):
+         start_date = self.transfer_date(start_date)
+         end_date = self.transfer_date(end_date)
+
+         super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
+
+         if "adj" in kwargs.keys():
+             self.adj = kwargs["adj"]
+             print(f"Using {self.adj} method.")
+         else:
+             self.adj = ""
+
+         if "period" in kwargs.keys():
+             self.period = kwargs["period"]
+         else:
+             self.period = "daily"
+
+     def get_data(self, id) -> pd.DataFrame:
+         return ak.stock_zh_a_hist(
+             symbol=id,
+             period=self.time_interval,
+             start_date=self.start_date,
+             end_date=self.end_date,
+             adjust=self.adj,
+         )
+
+     def download_data(
+         self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
+     ):
+         """
+         `pd.DataFrame`
+         7 columns: a tick symbol, time, open, high, low, close and volume
+         for the specified stock ticker
+         """
+         assert self.time_interval in [
+             "daily",
+             "weekly",
+             "monthly",
+         ], "Not supported currently"
+
+         self.ticker_list = ticker_list
+
+         self.dataframe = pd.DataFrame()
+         for i in tqdm(ticker_list, total=len(ticker_list)):
+             nonstandard_id = self.transfer_standard_ticker_to_nonstandard(i)
+             df_temp = self.get_data(nonstandard_id)
+             df_temp["tic"] = i
+             # df_temp = self.get_data(i)
+             self.dataframe = pd.concat([self.dataframe, df_temp])
+             # self.dataframe = self.dataframe.append(df_temp)
+             # print("{} ok".format(i))
+             time.sleep(0.25)
+
+         self.dataframe.columns = [
+             "time",
+             "open",
+             "close",
+             "high",
+             "low",
+             "volume",
+             "amount",
+             "amplitude",
+             "pct_chg",
+             "change",
+             "turnover",
+             "tic",
+         ]
+
+         self.dataframe.sort_values(by=["time", "tic"], inplace=True)
+         self.dataframe.reset_index(drop=True, inplace=True)
+
+         self.dataframe = self.dataframe[
+             ["tic", "time", "open", "high", "low", "close", "volume"]
+         ]
+         # self.dataframe.loc[:, 'tic'] = pd.DataFrame((self.dataframe['tic'].tolist()))
+         self.dataframe["time"] = pd.to_datetime(
+             self.dataframe["time"], format="%Y-%m-%d"
+         )
+         self.dataframe["day"] = self.dataframe["time"].dt.dayofweek
+         self.dataframe["time"] = self.dataframe.time.apply(
+             lambda x: x.strftime("%Y-%m-%d")
+         )
+
+         self.dataframe.dropna(inplace=True)
+         self.dataframe.sort_values(by=["time", "tic"], inplace=True)
+         self.dataframe.reset_index(drop=True, inplace=True)
+
+         self.save_data(save_path)
+
+         print(
+             f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
+         )
+
+     def data_split(self, df, start, end, target_date_col="time"):
+         """
+         Split the dataset into training or testing using time.
+         :param data: (df) pandas dataframe, start, end
+         :return: (df) pandas dataframe
+         """
+         data = df[(df[target_date_col] >= start) & (df[target_date_col] < end)]
+         data = data.sort_values([target_date_col, "tic"], ignore_index=True)
+         data.index = data[target_date_col].factorize()[0]
+         return data
+
+     def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str:
+         # "600000.XSHG" -> "600000"
+         # "000612.XSHE" -> "000612"
+         # "600000.SH" -> "600000"
+         # "000612.SZ" -> "000612"
+         if "." in ticker:
+             n, alpha = ticker.split(".")
+             # assert alpha in ["XSHG", "XSHE"], "Wrong alpha"
+             return n
+         return ticker
+
+     def transfer_date(self, time: str) -> str:
+         if "-" in time:
+             time = "".join(time.split("-"))
+         elif "." in time:
+             time = "".join(time.split("."))
+         elif "/" in time:
+             time = "".join(time.split("/"))
+         return time
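A hedged usage sketch for the Akshare processor above: dates pass through transfer_date, so dashed, dotted, and slashed formats are all accepted, and tickers are reduced to their six-digit codes before reaching ak.stock_zh_a_hist. The adj value "qfq" (forward-adjusted prices) is one of akshare's documented adjust options.

# Usage sketch, assuming the finnlp import path of this commit.
from finnlp.data_processors.akshare import Akshare

p = Akshare("akshare", "2021-01-01", "2021-06-30", "daily", adj="qfq")
p.download_data(["600000.XSHG", "000612.XSHE"])  # also writes ./data/dataset.csv
train = p.data_split(p.dataframe, "2021-01-01", "2021-05-01")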
finnlp/data_processors/alpaca.py ADDED
@@ -0,0 +1,441 @@
+ from typing import List
+
+ import alpaca_trade_api as tradeapi
+ import numpy as np
+ import pandas as pd
+ import pytz
+
+ try:
+     import exchange_calendars as tc
+ except ImportError:
+     print(
+         "Cannot import exchange_calendars.",
+         "If you are using python>=3.7, please install it.",
+     )
+     import trading_calendars as tc
+
+     print("Use trading_calendars instead for the alpaca processor.")
+ # from basic_processor import _Base
+ from finnlp.data_processors._base import _Base
+ from finnlp.data_processors._base import calc_time_zone
+
+ from finnlp.utils.config import (
+     TIME_ZONE_SHANGHAI,
+     TIME_ZONE_USEASTERN,
+     TIME_ZONE_PARIS,
+     TIME_ZONE_BERLIN,
+     TIME_ZONE_JAKARTA,
+     TIME_ZONE_SELFDEFINED,
+     USE_TIME_ZONE_SELFDEFINED,
+     BINANCE_BASE_URL,
+ )
+
+
+ class Alpaca(_Base):
+     # def __init__(self, API_KEY=None, API_SECRET=None, API_BASE_URL=None, api=None):
+     #     if api is None:
+     #         try:
+     #             self.api = tradeapi.REST(API_KEY, API_SECRET, API_BASE_URL, "v2")
+     #         except BaseException:
+     #             raise ValueError("Wrong Account Info!")
+     #     else:
+     #         self.api = api
+     def __init__(
+         self,
+         data_source: str,
+         start_date: str,
+         end_date: str,
+         time_interval: str,
+         **kwargs,
+     ):
+         super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
+         if kwargs["API"] is None:
+             try:
+                 self.api = tradeapi.REST(
+                     kwargs["API_KEY"],
+                     kwargs["API_SECRET"],
+                     kwargs["API_BASE_URL"],
+                     "v2",
+                 )
+             except BaseException:
+                 raise ValueError("Wrong Account Info!")
+         else:
+             self.api = kwargs["API"]
+
+     def download_data(
+         self,
+         ticker_list,
+         start_date,
+         end_date,
+         time_interval,
+         save_path: str = "./data/dataset.csv",
+     ) -> pd.DataFrame:
+         self.time_zone = calc_time_zone(
+             ticker_list, TIME_ZONE_SELFDEFINED, USE_TIME_ZONE_SELFDEFINED
+         )
+         start_date = pd.Timestamp(self.start_date, tz=self.time_zone)
+         end_date = pd.Timestamp(self.end_date, tz=self.time_zone) + pd.Timedelta(days=1)
+         self.time_interval = time_interval
+
+         date = start_date
+         data_df = pd.DataFrame()
+         while date != end_date:
+             start_time = (date + pd.Timedelta("09:30:00")).isoformat()
+             end_time = (date + pd.Timedelta("15:59:00")).isoformat()
+             for tic in ticker_list:
+                 barset = self.api.get_bars(
+                     tic,
+                     time_interval,
+                     start=start_time,
+                     end=end_time,
+                     limit=500,
+                 ).df
+                 barset["tic"] = tic
+                 barset = barset.reset_index()
+                 data_df = data_df.append(barset)
+             print(("Data before ") + end_time + " is successfully fetched")
+             # print(data_df.head())
+             date = date + pd.Timedelta(days=1)
+             if date.isoformat()[-14:-6] == "01:00:00":
+                 date = date - pd.Timedelta("01:00:00")
+             elif date.isoformat()[-14:-6] == "23:00:00":
+                 date = date + pd.Timedelta("01:00:00")
+             if date.isoformat()[-14:-6] != "00:00:00":
+                 raise ValueError("Timezone Error")
+
+         data_df["time"] = data_df["timestamp"].apply(
+             lambda x: x.strftime("%Y-%m-%d %H:%M:%S")
+         )
+         self.dataframe = data_df
+
+         self.save_data(save_path)
+
+         print(
+             f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
+         )
+
+     def clean_data(self):
+         df = self.dataframe.copy()
+         tic_list = np.unique(df.tic.values)
+
+         trading_days = self.get_trading_days(start=self.start_date, end=self.end_date)
+         # produce the full timestamp index
+         times = []
+         for day in trading_days:
+             current_time = pd.Timestamp(day + " 09:30:00").tz_localize(self.time_zone)
+             for _ in range(390):
+                 times.append(current_time)
+                 current_time += pd.Timedelta(minutes=1)
+         # create a new dataframe with the full time series
+         new_df = pd.DataFrame()
+         for tic in tic_list:
+             tmp_df = pd.DataFrame(
+                 columns=["open", "high", "low", "close", "volume"], index=times
+             )
+             tic_df = df[df.tic == tic]
+             for i in range(tic_df.shape[0]):
+                 tmp_df.loc[tic_df.iloc[i]["time"]] = tic_df.iloc[i][
+                     ["open", "high", "low", "close", "volume"]
+                 ]
+
+             # if the close price of the first row is NaN
+             if str(tmp_df.iloc[0]["close"]) == "nan":
+                 print(
+                     "The price of the first row for ticker ",
+                     tic,
+                     " is NaN. ",
+                     "It will be filled with the first valid price.",
+                 )
+                 for i in range(tmp_df.shape[0]):
+                     if str(tmp_df.iloc[i]["close"]) != "nan":
+                         first_valid_price = tmp_df.iloc[i]["close"]
+                         tmp_df.iloc[0] = [
+                             first_valid_price,
+                             first_valid_price,
+                             first_valid_price,
+                             first_valid_price,
+                             0.0,
+                         ]
+                         break
+             # if the close price of the first row is still NaN (all the prices are NaN in this case)
+             if str(tmp_df.iloc[0]["close"]) == "nan":
+                 print(
+                     "Missing data for ticker: ",
+                     tic,
+                     " . The prices are all NaN. Fill with 0.",
+                 )
+                 tmp_df.iloc[0] = [
+                     0.0,
+                     0.0,
+                     0.0,
+                     0.0,
+                     0.0,
+                 ]
+
+             # forward-fill row by row
+             for i in range(tmp_df.shape[0]):
+                 if str(tmp_df.iloc[i]["close"]) == "nan":
+                     previous_close = tmp_df.iloc[i - 1]["close"]
+                     if str(previous_close) == "nan":
+                         raise ValueError
+                     tmp_df.iloc[i] = [
+                         previous_close,
+                         previous_close,
+                         previous_close,
+                         previous_close,
+                         0.0,
+                     ]
+             tmp_df = tmp_df.astype(float)
+             tmp_df["tic"] = tic
+             new_df = new_df.append(tmp_df)
+
+         new_df = new_df.reset_index()
+         new_df = new_df.rename(columns={"index": "time"})
+
+         print("Data clean finished!")
+
+         self.dataframe = new_df
+
+     # def add_technical_indicator(
+     #     self,
+     #     df,
+     #     tech_indicator_list=[
+     #         "macd",
+     #         "boll_ub",
+     #         "boll_lb",
+     #         "rsi_30",
+     #         "dx_30",
+     #         "close_30_sma",
+     #         "close_60_sma",
+     #     ],
+     # ):
+     #     df = df.rename(columns={"time": "date"})
+     #     df = df.copy()
+     #     df = df.sort_values(by=["tic", "date"])
+     #     stock = Sdf.retype(df.copy())
+     #     unique_ticker = stock.tic.unique()
+     #     tech_indicator_list = tech_indicator_list
+     #
+     #     for indicator in tech_indicator_list:
+     #         indicator_df = pd.DataFrame()
+     #         for i in range(len(unique_ticker)):
+     #             # print(unique_ticker[i], i)
+     #             temp_indicator = stock[stock.tic == unique_ticker[i]][indicator]
+     #             temp_indicator = pd.DataFrame(temp_indicator)
+     #             temp_indicator["tic"] = unique_ticker[i]
+     #             # print(len(df[df.tic == unique_ticker[i]]['date'].to_list()))
+     #             temp_indicator["date"] = df[df.tic == unique_ticker[i]][
+     #                 "date"
+     #             ].to_list()
+     #             indicator_df = indicator_df.append(temp_indicator, ignore_index=True)
+     #         df = df.merge(
+     #             indicator_df[["tic", "date", indicator]], on=["tic", "date"], how="left"
+     #         )
+     #     df = df.sort_values(by=["date", "tic"])
+     #     df = df.rename(columns={"date": "time"})
+     #     print("Successfully added technical indicators")
+     #     return df
+
+     # def add_vix(self, data):
+     #     vix_df = self.download_data(["VIXY"], self.start, self.end, self.time_interval)
+     #     cleaned_vix = self.clean_data(vix_df)
+     #     vix = cleaned_vix[["time", "close"]]
+     #     vix = vix.rename(columns={"close": "VIXY"})
+     #
+     #     df = data.copy()
+     #     df = df.merge(vix, on="time")
+     #     df = df.sort_values(["time", "tic"]).reset_index(drop=True)
+     #     return df
+
+     # def calculate_turbulence(self, data, time_period=252):
+     #     # can add other market assets
+     #     df = data.copy()
+     #     df_price_pivot = df.pivot(index="date", columns="tic", values="close")
+     #     # use returns to calculate turbulence
+     #     df_price_pivot = df_price_pivot.pct_change()
+     #
+     #     unique_date = df.date.unique()
+     #     # start after a fixed time period
+     #     start = time_period
+     #     turbulence_index = [0] * start
+     #     # turbulence_index = [0]
+     #     count = 0
+     #     for i in range(start, len(unique_date)):
+     #         current_price = df_price_pivot[df_price_pivot.index == unique_date[i]]
+     #         # use a one-year rolling window to calculate covariance
+     #         hist_price = df_price_pivot[
+     #             (df_price_pivot.index < unique_date[i])
+     #             & (df_price_pivot.index >= unique_date[i - time_period])
+     #         ]
+     #         # Drop tickers which have more missing values than the "oldest" ticker
+     #         filtered_hist_price = hist_price.iloc[
+     #             hist_price.isna().sum().min() :
+     #         ].dropna(axis=1)
+     #
+     #         cov_temp = filtered_hist_price.cov()
+     #         current_temp = current_price[[x for x in filtered_hist_price]] - np.mean(
+     #             filtered_hist_price, axis=0
+     #         )
+     #         temp = current_temp.values.dot(np.linalg.pinv(cov_temp)).dot(
+     #             current_temp.values.T
+     #         )
+     #         if temp > 0:
+     #             count += 1
+     #             if count > 2:
+     #                 turbulence_temp = temp[0][0]
+     #             else:
+     #                 # avoid large outliers because the calculation just begins
+     #                 turbulence_temp = 0
+     #         else:
+     #             turbulence_temp = 0
+     #         turbulence_index.append(turbulence_temp)
+     #
+     #     turbulence_index = pd.DataFrame(
+     #         {"date": df_price_pivot.index, "turbulence": turbulence_index}
+     #     )
+     #     return turbulence_index
+     #
+     # def add_turbulence(self, data, time_period=252):
+     #     """
+     #     add turbulence index from a precalculated dataframe
+     #     :param data: (df) pandas dataframe
+     #     :return: (df) pandas dataframe
+     #     """
+     #     df = data.copy()
+     #     turbulence_index = self.calculate_turbulence(df, time_period=time_period)
+     #     df = df.merge(turbulence_index, on="date")
+     #     df = df.sort_values(["date", "tic"]).reset_index(drop=True)
+     #     return df
+
+     # def df_to_array(self, df, tech_indicator_list, if_vix):
+     #     df = df.copy()
+     #     unique_ticker = df.tic.unique()
+     #     if_first_time = True
+     #     for tic in unique_ticker:
+     #         if if_first_time:
+     #             price_array = df[df.tic == tic][["close"]].values
+     #             tech_array = df[df.tic == tic][tech_indicator_list].values
+     #             if if_vix:
+     #                 turbulence_array = df[df.tic == tic]["VIXY"].values
+     #             else:
+     #                 turbulence_array = df[df.tic == tic]["turbulence"].values
+     #             if_first_time = False
+     #         else:
+     #             price_array = np.hstack(
+     #                 [price_array, df[df.tic == tic][["close"]].values]
+     #             )
+     #             tech_array = np.hstack(
+     #                 [tech_array, df[df.tic == tic][tech_indicator_list].values]
+     #             )
+     #     print("Successfully transformed into array")
+     #     return price_array, tech_array, turbulence_array
+
+     def get_trading_days(self, start, end):
+         nyse = tc.get_calendar("NYSE")
+         df = nyse.sessions_in_range(
+             pd.Timestamp(start, tz=pytz.UTC), pd.Timestamp(end, tz=pytz.UTC)
+         )
+         return [str(day)[:10] for day in df]
+
+     def fetch_latest_data(
+         self, ticker_list, time_interval, tech_indicator_list, limit=100
+     ) -> pd.DataFrame:
+         data_df = pd.DataFrame()
+         for tic in ticker_list:
+             barset = self.api.get_barset([tic], time_interval, limit=limit).df[tic]
+             barset["tic"] = tic
+             barset = barset.reset_index()
+             data_df = data_df.append(barset)
+
+         data_df = data_df.reset_index(drop=True)
+         start_time = data_df.time.min()
+         end_time = data_df.time.max()
+         times = []
+         current_time = start_time
+         end = end_time + pd.Timedelta(minutes=1)
+         while current_time != end:
+             times.append(current_time)
+             current_time += pd.Timedelta(minutes=1)
+
+         df = data_df.copy()
+         new_df = pd.DataFrame()
+         for tic in ticker_list:
+             tmp_df = pd.DataFrame(
+                 columns=["open", "high", "low", "close", "volume"], index=times
+             )
+             tic_df = df[df.tic == tic]
+             for i in range(tic_df.shape[0]):
+                 tmp_df.loc[tic_df.iloc[i]["time"]] = tic_df.iloc[i][
+                     ["open", "high", "low", "close", "volume"]
+                 ]
+
+             if str(tmp_df.iloc[0]["close"]) == "nan":
+                 for i in range(tmp_df.shape[0]):
+                     if str(tmp_df.iloc[i]["close"]) != "nan":
+                         first_valid_close = tmp_df.iloc[i]["close"]
+                         tmp_df.iloc[0] = [
+                             first_valid_close,
+                             first_valid_close,
+                             first_valid_close,
+                             first_valid_close,
+                             0.0,
+                         ]
+                         break
+             if str(tmp_df.iloc[0]["close"]) == "nan":
+                 print(
+                     "Missing data for ticker: ",
+                     tic,
+                     " . The prices are all NaN. Fill with 0.",
+                 )
+                 tmp_df.iloc[0] = [
+                     0.0,
+                     0.0,
+                     0.0,
+                     0.0,
+                     0.0,
+                 ]
+
+             for i in range(tmp_df.shape[0]):
+                 if str(tmp_df.iloc[i]["close"]) == "nan":
+                     previous_close = tmp_df.iloc[i - 1]["close"]
+                     if str(previous_close) == "nan":
+                         raise ValueError
+                     tmp_df.iloc[i] = [
+                         previous_close,
+                         previous_close,
+                         previous_close,
+                         previous_close,
+                         0.0,
+                     ]
+             tmp_df = tmp_df.astype(float)
+             tmp_df["tic"] = tic
+             new_df = new_df.append(tmp_df)
+
+         new_df = new_df.reset_index()
+         new_df = new_df.rename(columns={"index": "time"})
+
+         df = self.add_technical_indicator(new_df, tech_indicator_list)
+         df["VIXY"] = 0
+
+         price_array, tech_array, turbulence_array = self.df_to_array(
+             df, tech_indicator_list, if_vix=True
+         )
+         latest_price = price_array[-1]
+         latest_tech = tech_array[-1]
+         turb_df = self.api.get_barset(["VIXY"], time_interval, limit=1).df["VIXY"]
+         latest_turb = turb_df["close"].values
+         return latest_price, latest_tech, latest_turb
+
+     def get_portfolio_history(self, start, end):
+         trading_days = self.get_trading_days(start, end)
+         df = pd.DataFrame()
+         for day in trading_days:
+             df = df.append(
+                 self.api.get_portfolio_history(
+                     date_start=day, timeframe="5Min"
+                 ).df.iloc[:79]
+             )
+         equities = df.equity.values
+         cumu_returns = equities / equities[0]
+         cumu_returns = cumu_returns[~np.isnan(cumu_returns)]
+         return cumu_returns
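A hedged usage sketch for the Alpaca processor: the constructor reads credentials from kwargs and expects an API entry even when it is None, so the key, secret, and paper-trading URL below are placeholders to be replaced.

# Usage sketch; replace the placeholder credentials with real ones.
from finnlp.data_processors.alpaca import Alpaca

p = Alpaca(
    "alpaca", "2021-01-04", "2021-01-08", "1Min",
    API_KEY="YOUR_KEY", API_SECRET="YOUR_SECRET",
    API_BASE_URL="https://paper-api.alpaca.markets", API=None,
)
p.download_data(["AAPL"], "2021-01-04", "2021-01-08", "1Min")
p.clean_data()  # reindexes each ticker onto a full 390-minute session grid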
finnlp/data_processors/alphavantage.py ADDED
@@ -0,0 +1,92 @@
+ import datetime
+ import json
+ from typing import List
+
+ import pandas as pd
+ import requests
+
+ from finnlp.utils.config import BINANCE_BASE_URL
+ from finnlp.utils.config import TIME_ZONE_BERLIN
+ from finnlp.utils.config import TIME_ZONE_JAKARTA
+ from finnlp.utils.config import TIME_ZONE_PARIS
+ from finnlp.utils.config import TIME_ZONE_SELFDEFINED
+ from finnlp.utils.config import TIME_ZONE_SHANGHAI
+ from finnlp.utils.config import TIME_ZONE_USEASTERN
+ from finnlp.utils.config import USE_TIME_ZONE_SELFDEFINED
+ from finnlp.data_processors._base import _Base
+ from finnlp.data_processors._base import calc_time_zone
+
+
+ def transfer_date(d):
+     date = str(d.year)
+     date += "-"
+     if len(str(d.month)) == 1:
+         date += "0" + str(d.month)
+     else:
+         date += str(d.month)
+     date += "-"
+     if len(str(d.day)) == 1:
+         date += "0" + str(d.day)
+     else:
+         date += str(d.day)
+     return date
+
+
+ class Alphavantage(_Base):
+     def __init__(
+         self,
+         data_source: str,
+         start_date: str,
+         end_date: str,
+         time_interval: str,
+         **kwargs,
+     ):
+         super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
+
+         assert time_interval == "1d", "please set the time_interval to 1d"
+
+     # ["1d"]
+     def download_data(
+         self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
+     ):
+         # self.time_zone = calc_time_zone(
+         #     ticker_list, TIME_ZONE_SELFDEFINED, USE_TIME_ZONE_SELFDEFINED
+         # )
+         self.dataframe = pd.DataFrame()
+         for ticker in ticker_list:
+             url = (
+                 "https://www.alphavantage.co/query?function=TIME_SERIES_DAILY&symbol="
+                 + ticker
+                 + "&apikey=demo"
+             )
+             r = requests.get(url)
+             data = r.json()
+             data2 = json.dumps(data["Time Series (Daily)"])
+             # gnData = json.dumps(data["Data"]["gn"])
+             df2 = pd.read_json(data2)
+             # gnDf = pd.read_json(gnData)
+
+             df3 = pd.DataFrame(df2.values.T, columns=df2.index, index=df2.columns)
+             df3.rename(
+                 columns={
+                     "1. open": "open",
+                     "2. high": "high",
+                     "3. low": "low",
+                     "4. close": "close",
+                     "5. volume": "volume",
+                 },
+                 inplace=True,
+             )
+             df3["tic"] = ticker
+             dates = [transfer_date(df2.index[i]) for i in range(len(df2.index))]
+             df3["date"] = dates
+             self.dataframe = pd.concat([self.dataframe, df3])
+         self.dataframe = self.dataframe.sort_values(by=["date", "tic"]).reset_index(
+             drop=True
+         )
+
+         self.save_data(save_path)
+
+         print(
+             f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
+         )
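A hedged usage sketch for the Alphavantage processor: download_data hard-codes apikey=demo, which Alpha Vantage only honors for its documentation examples (such as IBM), so a real API key would have to be substituted in the URL for other tickers.

# Usage sketch under the demo-key limitation noted above.
from finnlp.data_processors.alphavantage import Alphavantage

p = Alphavantage("alphavantage", "2020-01-01", "2021-01-01", "1d")
p.download_data(["IBM"])  # writes ./data/dataset.csv by default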
finnlp/data_processors/baostock.py ADDED
@@ -0,0 +1,114 @@
+ from typing import List
+
+ import baostock as bs
+ import numpy as np
+ import pandas as pd
+ import pytz
+ import yfinance as yf
+
+ """Reference: https://github.com/AI4Finance-LLC/FinRL"""
+
+ try:
+     import exchange_calendars as tc
+ except ImportError:
+     print(
+         "Cannot import exchange_calendars.",
+         "If you are using python>=3.7, please install it.",
+     )
+     import trading_calendars as tc
+
+     print("Use trading_calendars instead for the yahoofinance processor.")
+ # from basic_processor import _Base
+ from finnlp.data_processors._base import _Base
+ from finnlp.data_processors._base import calc_time_zone
+
+ from finnlp.utils.config import (
+     TIME_ZONE_SHANGHAI,
+     TIME_ZONE_USEASTERN,
+     TIME_ZONE_PARIS,
+     TIME_ZONE_BERLIN,
+     TIME_ZONE_JAKARTA,
+     TIME_ZONE_SELFDEFINED,
+     USE_TIME_ZONE_SELFDEFINED,
+     BINANCE_BASE_URL,
+ )
+
+
+ class Baostock(_Base):
+     def __init__(
+         self,
+         data_source: str,
+         start_date: str,
+         end_date: str,
+         time_interval: str,
+         **kwargs,
+     ):
+         super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
+
+     # daily, weekly, and monthly k-lines, plus 5-, 15-, 30-, and 60-minute k-line data
+     # ["5m", "15m", "30m", "60m", "1d", "1w", "1M"]
+     def download_data(
+         self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
+     ):
+         lg = bs.login()
+         print("baostock login respond error_code:" + lg.error_code)
+         print("baostock login respond error_msg:" + lg.error_msg)
+
+         self.time_zone = calc_time_zone(
+             ticker_list, TIME_ZONE_SELFDEFINED, USE_TIME_ZONE_SELFDEFINED
+         )
+         self.dataframe = pd.DataFrame()
+         for ticker in ticker_list:
+             nonstandard_ticker = self.transfer_standard_ticker_to_nonstandard(ticker)
+             # All supported: "date,code,open,high,low,close,preclose,volume,amount,adjustflag,turn,tradestatus,pctChg,isST"
+             rs = bs.query_history_k_data_plus(
+                 nonstandard_ticker,
+                 "date,code,open,high,low,close,volume",
+                 start_date=self.start_date,
+                 end_date=self.end_date,
+                 frequency=self.time_interval,
+                 adjustflag="3",
+             )
+
+             print("baostock download_data respond error_code:" + rs.error_code)
+             print("baostock download_data respond error_msg:" + rs.error_msg)
+
+             data_list = []
+             while (rs.error_code == "0") & rs.next():
+                 data_list.append(rs.get_row_data())
+             df = pd.DataFrame(data_list, columns=rs.fields)
+             df.loc[:, "code"] = [ticker] * df.shape[0]
+             self.dataframe = pd.concat([self.dataframe, df])
+         self.dataframe = self.dataframe.sort_values(by=["date", "code"]).reset_index(
+             drop=True
+         )
+         bs.logout()
+
+         self.dataframe.open = self.dataframe.open.astype(float)
+         self.dataframe.high = self.dataframe.high.astype(float)
+         self.dataframe.low = self.dataframe.low.astype(float)
+         self.dataframe.close = self.dataframe.close.astype(float)
+         self.save_data(save_path)
+
+         print(
+             f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
+         )
+
+     def get_trading_days(self, start, end):
+         lg = bs.login()
+         print("baostock login respond error_code:" + lg.error_code)
+         print("baostock login respond error_msg:" + lg.error_msg)
+         result = bs.query_trade_dates(start_date=start, end_date=end)
+         bs.logout()
+         return result
+
+     # "600000.XSHG" -> "sh.600000"
+     # "000612.XSHE" -> "sz.000612"
+     def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str:
+         n, alpha = ticker.split(".")
+         assert alpha in ["XSHG", "XSHE"], "Wrong alpha"
+         if alpha == "XSHG":
+             nonstandard_ticker = "sh." + n
+         elif alpha == "XSHE":
+             nonstandard_ticker = "sz." + n
+         return nonstandard_ticker
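A hedged usage sketch for the Baostock processor: tickers use the XSHG/XSHE suffixes, which transfer_standard_ticker_to_nonstandard maps to baostock's sh./sz. prefixes, and time_interval is passed straight through as the query frequency, so daily data uses "d".

# Usage sketch, assuming the finnlp import path of this commit.
from finnlp.data_processors.baostock import Baostock

p = Baostock("baostock", "2021-01-01", "2021-06-30", "d")
p.download_data(["600000.XSHG", "000612.XSHE"])
p.clean_data()  # renames "code" to "tic" and keeps the standard columns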
finnlp/data_processors/binance.py ADDED
@@ -0,0 +1,434 @@
+ import datetime as dt
+ import json
+ import os
+ import urllib.request
+ import zipfile
+ from datetime import *
+ from pathlib import Path
+ from typing import List
+
+ import pandas as pd
+ import requests
+
+ from finnlp.utils.config import BINANCE_BASE_URL
+ from finnlp.utils.config import TIME_ZONE_BERLIN
+ from finnlp.utils.config import TIME_ZONE_JAKARTA
+ from finnlp.utils.config import TIME_ZONE_PARIS
+ from finnlp.utils.config import TIME_ZONE_SELFDEFINED
+ from finnlp.utils.config import TIME_ZONE_SHANGHAI
+ from finnlp.utils.config import TIME_ZONE_USEASTERN
+ from finnlp.utils.config import USE_TIME_ZONE_SELFDEFINED
+ from finnlp.data_processors._base import _Base
+ from finnlp.data_processors._base import check_date
+
+ # from _base import check_date
+
+
+ class Binance(_Base):
+     def __init__(
+         self,
+         data_source: str,
+         start_date: str,
+         end_date: str,
+         time_interval: str,
+         **kwargs,
+     ):
+         if time_interval == "1D":
+             raise ValueError("Please use the time_interval 1d instead of 1D")
+         if time_interval == "1d":
+             check_date(start_date)
+             check_date(end_date)
+         super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
+         self.url = "https://api.binance.com/api/v3/klines"
+         self.time_diff = None
+
+     # main functions
+     def download_data(
+         self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
+     ):
+         startTime = dt.datetime.strptime(self.start_date, "%Y-%m-%d")
+         endTime = dt.datetime.strptime(self.end_date, "%Y-%m-%d")
+
+         self.start_time = self.stringify_dates(startTime)
+         self.end_time = self.stringify_dates(endTime)
+         self.interval = self.time_interval
+         self.limit = 1440
+
+         # 1s for now, will add support for variable time and variable tick soon
+         if self.time_interval == "1s":
+             # as per https://binance-docs.github.io/apidocs/spot/en/#compressed-aggregate-trades-list
+             self.limit = 1000
+             final_df = self.fetch_n_combine(self.start_date, self.end_date, ticker_list)
+         else:
+             final_df = pd.DataFrame()
+             for i in ticker_list:
+                 hist_data = self.dataframe_with_limit(symbol=i)
+                 df = hist_data.iloc[:-1].dropna()
+                 df["tic"] = i
+                 final_df = pd.concat([final_df, df], axis=0, join="outer")
+         self.dataframe = final_df
+
+         self.save_data(save_path)
+
+         print(
+             f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
+         )
+
+     # def clean_data(self, df):
+     #     df = df.dropna()
+     #     return df
+
+     # def add_technical_indicator(self, df, tech_indicator_list):
+     #     print('Adding self-defined technical indicators is NOT supported yet.')
+     #     print('Use default: MACD, RSI, CCI, DX.')
+     #     self.tech_indicator_list = ['open', 'high', 'low', 'close', 'volume',
+     #                                 'macd', 'macd_signal', 'macd_hist',
+     #                                 'rsi', 'cci', 'dx']
+     #     final_df = pd.DataFrame()
+     #     for i in df.tic.unique():
+     #         tic_df = df[df.tic==i]
+     #         tic_df['macd'], tic_df['macd_signal'], tic_df['macd_hist'] = MACD(tic_df['close'], fastperiod=12,
+     #                                                                           slowperiod=26, signalperiod=9)
+     #         tic_df['rsi'] = RSI(tic_df['close'], timeperiod=14)
+     #         tic_df['cci'] = CCI(tic_df['high'], tic_df['low'], tic_df['close'], timeperiod=14)
+     #         tic_df['dx'] = DX(tic_df['high'], tic_df['low'], tic_df['close'], timeperiod=14)
+     #         final_df = final_df.append(tic_df)
+     #
+     #     return final_df
+
+     # def add_turbulence(self, df):
+     #     print('Turbulence not supported yet. Return original DataFrame.')
+     #
+     #     return df
+
+     # def add_vix(self, df):
+     #     print('VIX is not applicable for cryptocurrencies. Return original DataFrame')
+     #
+     #     return df
+
+     # def df_to_array(self, df, tech_indicator_list, if_vix):
+     #     unique_ticker = df.tic.unique()
+     #     price_array = np.column_stack([df[df.tic==tic].close for tic in unique_ticker])
+     #     tech_array = np.hstack([df.loc[(df.tic==tic), tech_indicator_list] for tic in unique_ticker])
+     #     assert price_array.shape[0] == tech_array.shape[0]
+     #     return price_array, tech_array, np.array([])
+
+     # helper functions
+     def stringify_dates(self, date: dt.datetime):
+         return str(int(date.timestamp() * 1000))
+
+     def get_binance_bars(self, last_datetime, symbol):
+         """
+         The klines API returns data in the following order:
+         open_time, open_price, high_price, low_price, close_price,
+         volume, close_time, quote_asset_volume, n_trades,
+         taker_buy_base_asset_volume, taker_buy_quote_asset_volume,
+         ignore
+         """
+         req_params = {
+             "symbol": symbol,
+             "interval": self.interval,
+             "startTime": last_datetime,
+             "endTime": self.end_time,
+             "limit": self.limit,
+         }
+         # For debugging purposes, uncomment these lines; if they throw an error,
+         # then you may have an error in req_params
+         # r = requests.get(self.url, params=req_params)
+         # print(r.text)
+         df = pd.DataFrame(requests.get(self.url, params=req_params).json())
+
+         if df.empty:
+             return None
+
+         df = df.iloc[:, 0:6]
+         df.columns = ["datetime", "open", "high", "low", "close", "volume"]
+
+         df[["open", "high", "low", "close", "volume"]] = df[
+             ["open", "high", "low", "close", "volume"]
+         ].astype(float)
+
+         # No stock splits or dividend announcements, hence adjusted close is the same as close
+         df["adjusted_close"] = df["close"]
+         df["datetime"] = df.datetime.apply(
+             lambda x: dt.datetime.fromtimestamp(x / 1000.0)
+         )
+         df.reset_index(drop=True, inplace=True)
+
+         return df
+
+     def get_newest_bars(self, symbols, interval, limit):
+         merged_df = pd.DataFrame()
+         for symbol in symbols:
+             req_params = {
+                 "symbol": symbol,
+                 "interval": interval,
+                 "limit": limit,
+             }
+
+             df = pd.DataFrame(
+                 requests.get(self.url, params=req_params).json(),
+                 index=range(limit),
+             )
+
+             if df.empty:
+                 return None
+
+             df = df.iloc[:, 0:6]
+             df.columns = ["datetime", "open", "high", "low", "close", "volume"]
+
+             df[["open", "high", "low", "close", "volume"]] = df[
+                 ["open", "high", "low", "close", "volume"]
+             ].astype(float)
+
+             # No stock splits or dividend announcements, hence adjusted close is the same as close
+             df["adjusted_close"] = df["close"]
+             df["datetime"] = df.datetime.apply(
+                 lambda x: dt.datetime.fromtimestamp(x / 1000.0)
+             )
+             df["tic"] = symbol
+             df = df.rename(columns={"datetime": "time"})
+             df.reset_index(drop=True, inplace=True)
+             merged_df = merged_df.append(df)
+
+         return merged_df
+
+     def dataframe_with_limit(self, symbol):
+         final_df = pd.DataFrame()
+         last_datetime = self.start_time
+         while True:
+             new_df = self.get_binance_bars(last_datetime, symbol)
+             if new_df is None:
+                 break
+
+             if last_datetime == self.end_time:
+                 break
+
+             final_df = pd.concat([final_df, new_df], axis=0, join="outer")
+             # last_datetime = max(new_df.datetime) + dt.timedelta(days=1)
+             last_datetime = max(new_df.datetime)
+             if isinstance(last_datetime, pd.Timestamp):
+                 last_datetime = last_datetime.to_pydatetime()
+
+             if self.time_diff is None:
+                 self.time_diff = new_df.loc[1]["datetime"] - new_df.loc[0]["datetime"]
+
+             last_datetime = last_datetime + self.time_diff
+             last_datetime = self.stringify_dates(last_datetime)
+
+         date_value = final_df["datetime"].apply(
+             lambda x: x.strftime("%Y-%m-%d %H:%M:%S")
+         )
+         final_df.insert(0, "time", date_value)
+         final_df.drop("datetime", inplace=True, axis=1)
+         return final_df
+
+     def get_download_url(self, file_url):
+         return f"{BINANCE_BASE_URL}{file_url}"
+
+     # downloads the zip, unzips it, and deletes the zip
+     def download_n_unzip_file(self, base_path, file_name, date_range=None):
+         download_path = f"{base_path}{file_name}"
+         if date_range:
+             date_range = date_range.replace(" ", "_")
+             base_path = os.path.join(base_path, date_range)
+
+         # raw_cache_dir = get_destination_dir("./cache/tick_raw")
+         raw_cache_dir = "./cache/tick_raw"
+         zip_save_path = os.path.join(raw_cache_dir, file_name)
+
+         csv_name = os.path.splitext(file_name)[0] + ".csv"
+         csv_save_path = os.path.join(raw_cache_dir, csv_name)
+
+         fhandles = []
+
+         if os.path.exists(csv_save_path):
+             print(f"\nfile already exists! {csv_save_path}")
+             return [csv_save_path]
+
+         # make the "cache" directory (only)
+         if not os.path.exists(raw_cache_dir):
+             Path(raw_cache_dir).mkdir(parents=True, exist_ok=True)
+
+         try:
+             download_url = self.get_download_url(download_path)
+             dl_file = urllib.request.urlopen(download_url)
+             length = dl_file.getheader("content-length")
+             if length:
+                 length = int(length)
+                 blocksize = max(4096, length // 100)
+
+             with open(zip_save_path, "wb") as out_file:
+                 dl_progress = 0
+                 print(f"\nFile Download: {zip_save_path}")
+                 while True:
+                     buf = dl_file.read(blocksize)
+                     if not buf:
+                         break
+                     out_file.write(buf)
+                     # visuals
+                     # dl_progress += len(buf)
+                     # done = int(50 * dl_progress / length)
+ # done = int(50 * dl_progress / length)
272
+ # sys.stdout.write("\r[%s%s]" % ('#' * done, '.' * (50-done)) )
273
+ # sys.stdout.flush()
274
+
275
+ # unzip and delete zip
276
+ file = zipfile.ZipFile(zip_save_path)
277
+ with zipfile.ZipFile(zip_save_path) as zip:
278
+ # guaranteed just 1 csv
279
+ csvpath = zip.extract(zip.namelist()[0], raw_cache_dir)
280
+ fhandles.append(csvpath)
281
+ os.remove(zip_save_path)
282
+ return fhandles
283
+
284
+ except urllib.error.HTTPError:
285
+ print(f"\nFile not found: {download_url}")
286
+
287
+ def convert_to_date_object(self, d):
288
+ year, month, day = [int(x) for x in d.split("-")]
289
+ return date(year, month, day)
290
+
291
+ def get_path(
292
+ self,
293
+ trading_type,
294
+ market_data_type,
295
+ time_period,
296
+ symbol,
297
+ interval=None,
298
+ ):
299
+ trading_type_path = "data/spot"
300
+ # currently just supporting spot
301
+ if trading_type != "spot":
302
+ trading_type_path = f"data/futures/{trading_type}"
303
+ return (
304
+ f"{trading_type_path}/{time_period}/{market_data_type}/{symbol.upper()}/{interval}/"
305
+ if interval is not None
306
+ else f"{trading_type_path}/{time_period}/{market_data_type}/{symbol.upper()}/"
307
+ )
308
+
309
+ # helpers for manipulating tick level data (1s intervals)
310
+ def download_daily_aggTrades(
311
+ self, symbols, num_symbols, dates, start_date, end_date
312
+ ):
313
+ trading_type = "spot"
314
+ date_range = start_date + " " + end_date
315
+ start_date = self.convert_to_date_object(start_date)
316
+ end_date = self.convert_to_date_object(end_date)
317
+
318
+ print(f"Found {num_symbols} symbols")
319
+
320
+ map = {}
321
+ for current, symbol in enumerate(symbols):
322
+ map[symbol] = []
323
+ print(
324
+ f"[{current + 1}/{num_symbols}] - start download daily {symbol} aggTrades "
325
+ )
326
+ for date in dates:
327
+ current_date = self.convert_to_date_object(date)
328
+ if current_date >= start_date and current_date <= end_date:
329
+ path = self.get_path(trading_type, "aggTrades", "daily", symbol)
330
+ file_name = f"{symbol.upper()}-aggTrades-{date}.zip"
331
+ fhandle = self.download_n_unzip_file(path, file_name, date_range)
332
+ map[symbol] += fhandle
333
+ return map
334
+
335
+ def fetch_aggTrades(self, startDate: str, endDate: str, tickers: List[str]):
336
+ # all valid symbols traded on v3 api
337
+ response = urllib.request.urlopen(
338
+ "https://api.binance.com/api/v3/exchangeInfo"
339
+ ).read()
340
+ valid_symbols = list(
341
+ map(
342
+ lambda symbol: symbol["symbol"],
343
+ json.loads(response)["symbols"],
344
+ )
345
+ )
346
+
347
+ for tic in tickers:
348
+ if tic not in valid_symbols:
349
+ print(tic + " not a valid ticker, removing from download")
350
+ tickers = list(set(tickers) & set(valid_symbols))
351
+ num_symbols = len(tickers)
352
+ # not adding tz yet
353
+ # for ffill missing data on starting on first day 00:00:00 (if any)
354
+ tminus1 = (self.convert_to_date_object(startDate) - dt.timedelta(1)).strftime(
355
+ "%Y-%m-%d"
356
+ )
357
+ dates = pd.date_range(start=tminus1, end=endDate)
358
+ dates = [date.strftime("%Y-%m-%d") for date in dates]
359
+ return self.download_daily_aggTrades(
360
+ tickers, num_symbols, dates, tminus1, endDate
361
+ )
362
+
363
+ # Dict[str]:List[str] -> pd.DataFrame
364
+ def combine_raw(self, map):
365
+ # same format as jingyang's current data format
366
+ final_df = pd.DataFrame()
367
+ # using AggTrades with headers from https://github.com/binance/binance-public-data/
368
+ colNames = [
369
+ "AggregatetradeId",
370
+ "Price",
371
+ "volume",
372
+ "FirsttradeId",
373
+ "LasttradeId",
374
+ "time",
375
+ "buyerWasMaker",
376
+ "tradeWasBestPriceMatch",
377
+ ]
378
+ for tic in map.keys():
379
+ security = pd.DataFrame()
380
+ for i, csv in enumerate(map[tic]):
381
+ dailyticks = pd.read_csv(
382
+ csv,
383
+ names=colNames,
384
+ index_col=["time"],
385
+ parse_dates=["time"],
386
+ date_parser=lambda epoch: pd.to_datetime(epoch, unit="ms"),
387
+ )
388
+ dailyfinal = dailyticks.resample("1s").agg(
389
+ {"Price": "ohlc", "volume": "sum"}
390
+ )
391
+ dailyfinal.columns = dailyfinal.columns.droplevel(0)
392
+ # favor continuous series
393
+ # dailyfinal.dropna(inplace=True)
394
+
395
+ # implemented T-1 day ffill day start missing values
396
+ # guaranteed first csv is tminus1 day
397
+ if i == 0:
398
+ tmr = dailyfinal.index[0].date() + dt.timedelta(1)
399
+ tmr_dt = dt.datetime.combine(tmr, dt.time.min)
400
+ last_time_stamp_dt = dailyfinal.index[-1].to_pydatetime()
401
+ s_delta = (tmr_dt - last_time_stamp_dt).seconds
402
+ lastsample = dailyfinal.iloc[-1:]
403
+ lastsample.index = lastsample.index.shift(s_delta, "s")
404
+ else:
405
+ day_dt = dailyfinal.index[0].date()
406
+ day_str = day_dt.strftime("%Y-%m-%d")
407
+ nextday_str = (day_dt + dt.timedelta(1)).strftime("%Y-%m-%d")
408
+ if dailyfinal.index[0].second != 0:
409
+ # append last sample
410
+ dailyfinal = lastsample.append(dailyfinal)
411
+ # otherwise, just reindex and ffill
412
+ dailyfinal = dailyfinal.reindex(
413
+ pd.date_range(day_str, nextday_str, freq="1s")[:-1],
414
+ method="ffill",
415
+ )
416
+ # save reference info (guaranteed to be :59)
417
+ lastsample = dailyfinal.iloc[-1:]
418
+ lastsample.index = lastsample.index.shift(1, "s")
419
+
420
+ if dailyfinal.shape[0] != 86400:
421
+ raise ValueError("everyday should have 86400 datapoints")
422
+
423
+ # only save real startDate - endDate
424
+ security = security.append(dailyfinal)
425
+
426
+ security.ffill(inplace=True)
427
+ security["tic"] = tic
428
+ final_df = final_df.append(security)
429
+ return final_df
430
+
431
+ def fetch_n_combine(self, startDate, endDate, tickers):
432
+ # return combine_raw(fetchAggTrades(startDate, endDate, tickers))
433
+ mapping = self.fetch_aggTrades(startDate, endDate, tickers)
434
+ return self.combine_raw(mapping)
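
A note on the pagination above: dataframe_with_limit repeatedly calls get_binance_bars, advancing startTime one bar past the last candle returned until the requested window is exhausted. A minimal standalone sketch of the same loop against the public klines endpoint follows; the symbol, interval, limit, and epoch-millisecond bounds are illustrative values, not taken from this class:

    # Minimal sketch of klines pagination; BTCUSDT/1m/500 and the timestamps are example values.
    import requests
    import pandas as pd

    url = "https://api.binance.com/api/v3/klines"
    start_ms, end_ms = 1609459200000, 1609462800000
    rows = []
    while start_ms < end_ms:
        batch = requests.get(url, params={
            "symbol": "BTCUSDT", "interval": "1m",
            "startTime": start_ms, "endTime": end_ms, "limit": 500,
        }).json()
        if not batch:
            break
        rows.extend(batch)
        start_ms = batch[-1][6] + 1  # close_time of the last candle, plus one millisecond

    df = pd.DataFrame(rows).iloc[:, 0:6]
    df.columns = ["datetime", "open", "high", "low", "close", "volume"]
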
finnlp/data_processors/ccxt.py ADDED
@@ -0,0 +1,143 @@
+ import calendar
+ from datetime import datetime
+ from typing import List
+
+ import ccxt
+ import numpy as np
+ import pandas as pd
+
+ from finnlp.data_processors._base import _Base
+
+ # from basic_processor import _Base
+
+
+ class Ccxt(_Base):
+     def __init__(
+         self,
+         data_source: str,
+         start_date: str,
+         end_date: str,
+         time_interval: str,
+         **kwargs,
+     ):
+         super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
+         self.binance = ccxt.binance()
+
+     def download_data(
+         self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
+     ):
+         crypto_column = pd.MultiIndex.from_product(
+             [ticker_list, ["open", "high", "low", "close", "volume"]]
+         )
+         first_time = True
+         for ticker in ticker_list:
+             start_dt = datetime.strptime(self.start_date, "%Y%m%d %H:%M:%S")
+             end_dt = datetime.strptime(self.end_date, "%Y%m%d %H:%M:%S")
+             start_timestamp = calendar.timegm(start_dt.utctimetuple())
+             end_timestamp = calendar.timegm(end_dt.utctimetuple())
+             if self.time_interval == "1Min":
+                 date_list = [
+                     datetime.utcfromtimestamp(float(time))
+                     for time in range(start_timestamp, end_timestamp, 60 * 720)
+                 ]
+             else:
+                 date_list = [
+                     datetime.utcfromtimestamp(float(time))
+                     for time in range(start_timestamp, end_timestamp, 60 * 1440)
+                 ]
+             df = self.ohlcv(date_list, ticker, self.time_interval)
+             if first_time:
+                 dataset = pd.DataFrame(columns=crypto_column, index=df["time"].values)
+                 first_time = False
+             temp_col = pd.MultiIndex.from_product(
+                 [[ticker], ["open", "high", "low", "close", "volume"]]
+             )
+             dataset[temp_col] = df[["open", "high", "low", "close", "volume"]].values
+             print("Actual end time: " + str(df["time"].values[-1]))
+         self.dataframe = dataset
+
+         self.save_data(save_path)
+
+         print(
+             f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
+         )
+
+     # def add_technical_indicators(self, df, pair_list, tech_indicator_list = [
+     #     'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'dx_30',
+     #     'close_30_sma', 'close_60_sma']):
+     #     df = df.dropna()
+     #     df = df.copy()
+     #     column_list = [pair_list, ['open','high','low','close','volume']+(tech_indicator_list)]
+     #     column = pd.MultiIndex.from_product(column_list)
+     #     index_list = df.index
+     #     dataset = pd.DataFrame(columns=column, index=index_list)
+     #     for pair in pair_list:
+     #         pair_column = pd.MultiIndex.from_product([[pair],['open','high','low','close','volume']])
+     #         dataset[pair_column] = df[pair]
+     #         temp_df = df[pair].reset_index().sort_values(by=['index'])
+     #         temp_df = temp_df.rename(columns={'index':'date'})
+     #         crypto_df = Sdf.retype(temp_df.copy())
+     #         for indicator in tech_indicator_list:
+     #             temp_indicator = crypto_df[indicator].values.tolist()
+     #             dataset[(pair,indicator)] = temp_indicator
+     #     print('Successfully added technical indicators')
+     #     return dataset
+
+     def df_to_ary(self, pair_list, tech_indicator_list=None):
+         if tech_indicator_list is None:
+             tech_indicator_list = [
+                 "macd",
+                 "boll_ub",
+                 "boll_lb",
+                 "rsi_30",
+                 "dx_30",
+                 "close_30_sma",
+                 "close_60_sma",
+             ]
+         df = self.dataframe
+         df = df.dropna()
+         date_ary = df.index.values
+         price_array = df[pd.MultiIndex.from_product([pair_list, ["close"]])].values
+         tech_array = df[
+             pd.MultiIndex.from_product([pair_list, tech_indicator_list])
+         ].values
+         return price_array, tech_array, date_ary
+
+     def min_ohlcv(self, dt, pair, limit):
+         since = calendar.timegm(dt.utctimetuple()) * 1000
+         return self.binance.fetch_ohlcv(
+             symbol=pair, timeframe="1m", since=since, limit=limit
+         )
+
+     def ohlcv(self, dt, pair, period="1d"):
+         ohlcv = []
+         limit = 1000
+         if period == "1Min":
+             limit = 720
+         elif period == "1D":
+             limit = 1
+         elif period == "1H":
+             limit = 24
+         elif period == "5Min":
+             limit = 288
+         for i in dt:
+             start_dt = i
+             since = calendar.timegm(start_dt.utctimetuple()) * 1000
+             if period == "1Min":
+                 ohlcv.extend(self.min_ohlcv(start_dt, pair, limit))
+             else:
+                 ohlcv.extend(
+                     self.binance.fetch_ohlcv(
+                         symbol=pair, timeframe=period, since=since, limit=limit
+                     )
+                 )
+         df = pd.DataFrame(
+             ohlcv, columns=["time", "open", "high", "low", "close", "volume"]
+         )
+         df["time"] = [datetime.fromtimestamp(float(time) / 1000) for time in df["time"]]
+         df["open"] = df["open"].astype(np.float64)
+         df["high"] = df["high"].astype(np.float64)
+         df["low"] = df["low"].astype(np.float64)
+         df["close"] = df["close"].astype(np.float64)
+         df["volume"] = df["volume"].astype(np.float64)
+         return df
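
For reference, a minimal usage sketch of the Ccxt processor defined above, assuming network access to Binance via ccxt; the dates follow the %Y%m%d %H:%M:%S format that download_data parses, and the pair symbol is illustrative:

    # Hypothetical usage of the Ccxt processor; all argument values are examples.
    processor = Ccxt(
        data_source="ccxt",
        start_date="20210101 00:00:00",
        end_date="20210131 00:00:00",
        time_interval="1d",
    )
    processor.download_data(["BTC/USDT"], save_path="./data/btc.csv")
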
finnlp/data_processors/fx.py ADDED
@@ -0,0 +1,102 @@
+ import os
+ import sys
+
+ import pandas as pd
+ from finta import TA
+
+
+ def add_time_feature(df, symbol, dt_col_name="time"):
+     """Read a csv into a df and index it on time.
+     dt_col_name can be any unit from minutes to days; time becomes the index of the df.
+     The df must have the columns [(time_col), (asset_col), Open, Close, High, Low, day].
+     data_process adds additional time information: time (index), minute, hour,
+     weekday, week, month, year, day (since 1970).
+     Use StopLoss and ProfitTaken to simplify the action:
+     feed a fixed StopLoss (SL = 200) and PT = SL * ratio.
+     action space: [action[0,2], ratio[0,10]]
+     the reward is in points
+
+     Adds hour and dayofweek (0-6, Sun-Sat).
+     Args:
+         file (str): file path/name.csv
+     """
+
+     df["symbol"] = symbol
+     df["dt"] = pd.to_datetime(df[dt_col_name])
+     df.index = df["dt"]
+     df["minute"] = df["dt"].dt.minute
+     df["hour"] = df["dt"].dt.hour
+     df["weekday"] = df["dt"].dt.dayofweek
+     df["week"] = df["dt"].dt.isocalendar().week
+     df["month"] = df["dt"].dt.month
+     df["year"] = df["dt"].dt.year
+     df["day"] = df["dt"].dt.day
+     # df = df.set_index('dt')
+     return df
+
+
+ # 'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'dx_30', 'close_30_sma', 'close_60_sma'
+ def tech_indictors(df):
+     df["macd"] = TA.MACD(df).SIGNAL
+     df["boll_ub"] = TA.BBANDS(df).BB_UPPER
+     df["boll_lb"] = TA.BBANDS(df).BB_LOWER
+     df["rsi_30"] = TA.RSI(df, period=30)
+     df["dx_30"] = TA.ADX(df, period=30)
+     df["close_30_sma"] = TA.SMA(df, period=30)
+     df["close_60_sma"] = TA.SMA(df, period=60)
+
+     # fill NaN with 0
+     df = df.fillna(0)
+     print(
+         f"--------df head - tail ----------------\n{df.head(3)}\n{df.tail(3)}\n---------------------------------"
+     )
+
+     return df
+
+
+ def split_timeserious(df, key_ts="dt", freq="W", symbol=""):
+     """Split the df into hourly, daily, weekly, or monthly frames and
+     save them into subfolders.
+
+     Args:
+         df (pandas df with the timestamp as part of a multi index)
+         freq (str): H, D, W, M, Y
+     """
+
+     freq_name = {
+         "H": "hourly",
+         "D": "daily",
+         "W": "weekly",
+         "M": "monthly",
+         "Y": "yearly",
+     }
+     for count, (n, g) in enumerate(df.groupby(pd.Grouper(level=key_ts, freq=freq))):
+         p = f"./data/split/{symbol}/{freq_name[freq]}"
+         os.makedirs(p, exist_ok=True)
+         # fname = f'{symbol}_{n:%Y%m%d}_{freq}_{count}.csv'
+         fname = f"{symbol}_{n:%Y}_{count}.csv"
+         fn = f"{p}/{fname}"
+         print(f"save to: {fn}")
+         g.reset_index(drop=True, inplace=True)
+         g.drop(columns=["dt"], inplace=True)
+         g.to_csv(fn)
+     return
+
+
+ """
+ python ./neo_finrl/data_processors/fx.py GBPUSD W ./data/raw/GBPUSD_raw.csv
+ symbol="GBPUSD"
+ freq = [H, D, W, M]
+ file: .csv with column names [time, Open, High, Low, Close, Vol]
+ """
+ if __name__ == "__main__":
+     symbol, freq, file = sys.argv[1], sys.argv[2], sys.argv[3]
+     print(f"processing... symbol:{symbol} freq:{freq} file:{file}")
+     try:
+         df = pd.read_csv(file)
+     except Exception:
+         print(f"No such file or directory: {file}")
+         sys.exit(1)  # exit with a nonzero code on failure
+     df = add_time_feature(df, symbol=symbol, dt_col_name="time")
+     df = tech_indictors(df)
+     split_timeserious(df, freq=freq, symbol=symbol)
+     print("Done!")
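
The __main__ block above documents the CLI contract (python fx.py SYMBOL FREQ FILE). The same pipeline can also be driven programmatically; a sketch, with the csv path as a placeholder:

    # Programmatic sketch of the fx.py pipeline; the path is a placeholder and the
    # csv must provide a "time" column plus the OHLC columns that finta expects.
    import pandas as pd

    df = pd.read_csv("./data/raw/GBPUSD_raw.csv")
    df = add_time_feature(df, symbol="GBPUSD", dt_col_name="time")
    df = tech_indictors(df)
    split_timeserious(df, freq="W", symbol="GBPUSD")  # writes ./data/split/GBPUSD/weekly/*.csv
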
finnlp/data_processors/iexcloud.py ADDED
@@ -0,0 +1,143 @@
+ import os
+ from datetime import datetime
+ from typing import List
+
+ import pandas as pd
+ import pandas_market_calendars as mcal
+ import pytz
+ import requests
+
+ from finnlp.data_processors._base import _Base
+
+ # from _base import _Base
+
+
+ class Iexcloud(_Base):
+     @classmethod
+     def _get_base_url(cls, mode: str) -> str:
+         as1 = "mode must be sandbox or production."
+         assert mode in {"sandbox", "production"}, as1
+
+         if mode == "sandbox":
+             return "https://sandbox.iexapis.com"
+
+         return "https://cloud.iexapis.com"
+
+     def __init__(
+         self,
+         data_source: str,
+         start_date: str,
+         end_date: str,
+         time_interval: str,
+         **kwargs,
+     ):
+         super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
+         self.base_url = self._get_base_url(mode=kwargs["mode"])
+         self.token = kwargs.get("token") or os.environ.get("IEX_TOKEN")
+
+     def download_data(
+         self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
+     ):
+         """Returns end-of-day historical data for up to 15 years.
+
+         Args:
+             ticker_list (List[str]): list of the tickers to retrieve information for.
+             start_date (str): oldest date of the range.
+             end_date (str): latest date of the range.
+
+         Returns:
+             pd.DataFrame: a pandas dataframe with end-of-day historical data
+                 for the specified tickers with the following columns:
+                 date, tic, open, high, low, close, adjusted_close, volume.
+
+         Examples:
+             kwargs['mode'] = 'sandbox'
+             kwargs['token'] = 'Tsk_d633e2ff10d463...'
+             >>> iex_dloader = Iexcloud(data_source='iexcloud', **kwargs)
+             >>> iex_dloader.download_data(ticker_list=["AAPL", "NVDA"],
+                                           start_date='2014-01-01',
+                                           end_date='2021-12-12',
+                                           time_interval='1D')
+         """
+         assert self.time_interval == "1D"  # one day
+
+         price_data = pd.DataFrame()
+
+         query_params = {
+             "token": self.token,
+         }
+
+         if self.start_date and self.end_date:
+             query_params["from"] = self.start_date
+             query_params["to"] = self.end_date
+
+         for stock in ticker_list:
+             end_point = f"{self.base_url}/stable/time-series/HISTORICAL_PRICES/{stock}"
+
+             response = requests.get(
+                 url=end_point,
+                 params=query_params,
+             )
+             if response.status_code != 200:
+                 raise requests.exceptions.RequestException(response.text)
+
+             temp = pd.DataFrame.from_dict(data=response.json())
+             temp["ticker"] = stock
+             price_data = pd.concat([price_data, temp], ignore_index=True)
+         price_data = price_data[
+             [
+                 "date",
+                 "ticker",
+                 "open",
+                 "high",
+                 "low",
+                 "close",
+                 "fclose",
+                 "volume",
+             ]
+         ]
+         price_data = price_data.rename(
+             columns={
+                 "ticker": "tic",
+                 "date": "time",
+                 "fclose": "adjusted_close",
+             }
+         )
+
+         # the "date" column was renamed to "time" above, so convert that column
+         price_data["time"] = price_data["time"].map(
+             lambda x: datetime.fromtimestamp(x / 1000, pytz.UTC).strftime("%Y-%m-%d")
+         )
+
+         self.dataframe = price_data
+
+         self.save_data(save_path)
+
+         print(
+             f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
+         )
+
+     def get_trading_days(self, start: str, end: str) -> List[str]:
+         """Retrieves every trading day between two dates.
+
+         Args:
+             start (str): oldest date of the range.
+             end (str): latest date of the range.
+
+         Returns:
+             List[str]: list of all trading days in YYYY-mm-dd format.
+
+         Examples:
+             >>> iex_dloader = Iexcloud(data_source='iexcloud',
+                                        mode='sandbox',
+                                        token='Tsk_d633e2ff10d463...')
+             >>> iex_dloader.get_trading_days(start='2014-01-01',
+                                              end='2021-12-12')
+             ['2021-12-15', '2021-12-16', '2021-12-17']
+         """
+         nyse = mcal.get_calendar("NYSE")
+
+         df = nyse.schedule(
+             start_date=pd.Timestamp(start, tz=pytz.UTC),
+             end_date=pd.Timestamp(end, tz=pytz.UTC),
+         )
+         return df.applymap(lambda x: x.strftime("%Y-%m-%d")).market_open.to_list()
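
download_data above boils down to one REST call per ticker. Stripped of the DataFrame plumbing, the request it issues looks roughly like this; the token and ticker are placeholders:

    # Sketch of the underlying IEX Cloud request; token and ticker are placeholders.
    import requests

    resp = requests.get(
        "https://cloud.iexapis.com/stable/time-series/HISTORICAL_PRICES/AAPL",
        params={"token": "pk_your_token_here", "from": "2021-01-01", "to": "2021-12-31"},
    )
    resp.raise_for_status()
    records = resp.json()  # list of dicts including date (epoch ms), open, high, low, close, fclose, volume
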
finnlp/data_processors/joinquant.py ADDED
@@ -0,0 +1,68 @@
+ import copy
+ import datetime
+ import os
+ from typing import List
+
+ import jqdatasdk as jq
+ import numpy as np
+
+ from finnlp.data_processors._base import _Base
+
+
+ class Joinquant(_Base):
+     def __init__(
+         self,
+         data_source: str,
+         start_date: str,
+         end_date: str,
+         time_interval: str,
+         **kwargs,
+     ):
+         super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
+         if "username" in kwargs.keys() and "password" in kwargs.keys():
+             jq.auth(kwargs["username"], kwargs["password"])
+
+     def download_data(
+         self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
+     ):
+         # JoinQuant supports: '1m', '5m', '15m', '30m', '60m', '120m', '1d', '1w', '1M'.
+         # '1w' denotes one week; '1M' denotes one month.
+         count = len(self.get_trading_days(self.start_date, self.end_date))
+         df = jq.get_bars(
+             security=ticker_list,
+             count=count,
+             unit=self.time_interval,
+             fields=["date", "open", "high", "low", "close", "volume"],
+             end_dt=self.end_date,
+         )
+         df = df.reset_index().rename(columns={"level_0": "tic"})
+         self.dataframe = df
+
+         self.save_data(save_path)
+
+         print(
+             f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
+         )
+
+     @staticmethod
+     def preprocess(df, stock_list):  # does not use instance state, hence a staticmethod
+         n = len(stock_list)
+         N = df.shape[0]
+         assert N % n == 0
+         d = int(N / n)
+         stock1_ary = df.iloc[0:d, 1:].values
+         temp_ary = stock1_ary
+         for j in range(1, n):
+             stocki_ary = df.iloc[j * d : (j + 1) * d, 1:].values
+             temp_ary = np.hstack((temp_ary, stocki_ary))
+         return temp_ary
+
+     # start_day: str
+     # end_day: str
+     # output: list of str_of_trade_day, e.g., ['2021-09-01', '2021-09-02']
+     def get_trading_days(self, start_day: str, end_day: str) -> List[str]:
+         dates = jq.get_trade_days(start_day, end_day)
+         str_dates = []
+         for d in dates:
+             tmp = datetime.date.strftime(d, "%Y-%m-%d")
+             str_dates.append(tmp)
+         # str_dates = [date2str(dt) for dt in dates]
+         return str_dates
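
preprocess above assumes the rows arrive grouped ticker by ticker in equal-length blocks and stacks those blocks side by side. A tiny worked example on synthetic data:

    # Synthetic demo of Joinquant.preprocess: 2 tickers x 3 dates -> a (3, 4) array.
    import pandas as pd

    df = pd.DataFrame({
        "tic": ["A", "A", "A", "B", "B", "B"],  # the first column is dropped by iloc[:, 1:]
        "close": [1.0, 2.0, 3.0, 10.0, 20.0, 30.0],
        "volume": [5, 6, 7, 8, 9, 10],
    })
    ary = Joinquant.preprocess(df, stock_list=["A", "B"])
    print(ary.shape)  # (3, 4): A's (close, volume) columns followed by B's
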
finnlp/data_processors/quandl.py ADDED
@@ -0,0 +1,85 @@
+ from typing import List
+
+ import numpy as np
+ import pandas as pd
+ import pytz
+ import quandl
+ import yfinance as yf
+
+ """Reference: https://github.com/AI4Finance-LLC/FinRL"""
+
+ try:
+     import exchange_calendars as tc
+ except ImportError:
+     print(
+         "Cannot import exchange_calendars.",
+         "If you are using python>=3.7, please install it.",
+     )
+     import trading_calendars as tc
+
+     print("Using trading_calendars instead for the quandl processor.")
+ # from basic_processor import _Base
+ from finnlp.data_processors._base import _Base
+ from finnlp.data_processors._base import calc_time_zone
+
+ from finnlp.utils.config import (
+     TIME_ZONE_SHANGHAI,
+     TIME_ZONE_USEASTERN,
+     TIME_ZONE_PARIS,
+     TIME_ZONE_BERLIN,
+     TIME_ZONE_JAKARTA,
+     TIME_ZONE_SELFDEFINED,
+     USE_TIME_ZONE_SELFDEFINED,
+     BINANCE_BASE_URL,
+ )
+
+ TIME_ZONE_SELFDEFINED = TIME_ZONE_USEASTERN  # If neither of the above is your time zone, you should define it and set USE_TIME_ZONE_SELFDEFINED to 1.
+ USE_TIME_ZONE_SELFDEFINED = 1  # 0 (default) or 1 (use the self-defined one)
+
+
+ class Quandl(_Base):
+     def __init__(
+         self,
+         data_source: str,
+         start_date: str,
+         end_date: str,
+         time_interval: str,
+         **kwargs,
+     ):
+         super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
+
+     def download_data(
+         self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
+     ):
+         self.time_zone = calc_time_zone(
+             ticker_list, TIME_ZONE_SELFDEFINED, USE_TIME_ZONE_SELFDEFINED
+         )
+
+         # Download and save the data in a pandas DataFrame:
+         # data_df = pd.DataFrame()
+         # # set paginate to True because Quandl limits the tables API to 10,000 rows per call
+         # data = quandl.get_table('ZACKS/FC', paginate=True, ticker=ticker_list, per_end_date={'gte': '2021-09-01'}, qopts={'columns': ['ticker', 'per_end_date']})
+         # data = quandl.get('ZACKS/FC', ticker=ticker_list, start_date="2020-12-31", end_date="2021-12-31")
+         self.dataframe = quandl.get_table(
+             "ZACKS/FC",
+             ticker=ticker_list,
+             qopts={"columns": ["ticker", "date", "adjusted_close"]},
+             date={"gte": self.start_date, "lte": self.end_date},
+             paginate=True,
+         )
+         self.dataframe.dropna(inplace=True)
+         self.dataframe.reset_index(drop=True, inplace=True)
+         print("Shape of DataFrame: ", self.dataframe.shape)
+         # print("Display DataFrame: ", data_df.head())
+
+         self.dataframe.sort_values(by=["date", "ticker"], inplace=True)
+         self.dataframe.reset_index(drop=True, inplace=True)
+
+         self.save_data(save_path)
+
+         print(
+             f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
+         )
+
+     # def get_trading_days(self, start, end):
+     #
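
A usage sketch for the Quandl processor; quandl reads its API key from ApiConfig, and the key, tickers, and dates below are placeholders (ZACKS/FC is a subscription table):

    # Hypothetical usage; requires a Quandl/Nasdaq Data Link API key with ZACKS/FC access.
    import quandl

    quandl.ApiConfig.api_key = "your_api_key_here"
    processor = Quandl(
        data_source="quandl",
        start_date="2021-01-01",
        end_date="2021-12-31",
        time_interval="1d",
    )
    processor.download_data(["AAPL", "MSFT"], save_path="./data/quandl.csv")
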
finnlp/data_processors/quantconnect.py ADDED
@@ -0,0 +1,70 @@
+ from typing import List
+
+ from finnlp.utils.config import BINANCE_BASE_URL
+ from finnlp.utils.config import TIME_ZONE_BERLIN
+ from finnlp.utils.config import TIME_ZONE_JAKARTA
+ from finnlp.utils.config import TIME_ZONE_PARIS
+ from finnlp.utils.config import TIME_ZONE_SELFDEFINED
+ from finnlp.utils.config import TIME_ZONE_SHANGHAI
+ from finnlp.utils.config import TIME_ZONE_USEASTERN
+ from finnlp.utils.config import USE_TIME_ZONE_SELFDEFINED
+ from finnlp.data_processors._base import _Base
+
+ # from basic_processor import _Base
+
+
+ ## The code in this file is meant to run on the QuantConnect website, not locally.
+ class Quantconnect(_Base):
+     def __init__(
+         self,
+         data_source: str,
+         start_date: str,
+         end_date: str,
+         time_interval: str,
+         **kwargs,
+     ):
+         super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
+
+     # def data_fetch(start_time, end_time, stock_list, resolution=Resolution.Daily):
+     #     # resolution: Daily, Hour, Minute, Second
+     #     qb = QuantBook()
+     #     for stock in stock_list:
+     #         qb.AddEquity(stock)
+     #     history = qb.History(qb.Securities.Keys, start_time, end_time, resolution)
+     #     return history
+
+     def download_data(
+         self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
+     ):
+         # self.time_zone = calc_time_zone(ticker_list, TIME_ZONE_SELFDEFINED, USE_TIME_ZONE_SELFDEFINED)
+
+         # start_date = pd.Timestamp(start_date, tz=self.time_zone)
+         # end_date = pd.Timestamp(end_date, tz=self.time_zone) + pd.Timedelta(days=1)
+         qb = QuantBook()  # QuantBook is provided by the QuantConnect research environment
+         for stock in ticker_list:
+             qb.AddEquity(stock)
+         history = qb.History(
+             qb.Securities.Keys,
+             self.start_date,
+             self.end_date,
+             self.time_interval,
+         )
+         self.dataframe = history
+
+         self.save_data(save_path)
+
+         print(
+             f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
+         )
+
+     # def preprocess(df, stock_list):
+     #     df = df[['open','high','low','close','volume']]
+     #     if_first_time = True
+     #     for stock in stock_list:
+     #         if if_first_time:
+     #             ary = df.loc[stock].values
+     #             if_first_time = False
+     #         else:
+     #             temp = df.loc[stock].values
+     #             ary = np.hstack((ary,temp))
+     #     return ary
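
QuantBook is injected by QuantConnect's hosted research environment, which is why the comment above says this file does not run locally. Inside a QuantConnect research notebook the equivalent direct calls would look roughly like this (QuantBook and Resolution are predefined there; the ticker and dates are examples):

    # Runs only inside a QuantConnect research notebook, where QuantBook and Resolution exist.
    from datetime import datetime

    qb = QuantBook()
    qb.AddEquity("AAPL")
    history = qb.History(qb.Securities.Keys, datetime(2021, 1, 1), datetime(2021, 2, 1), Resolution.Daily)
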
finnlp/data_processors/ricequant.py ADDED
@@ -0,0 +1,131 @@
+ from typing import List
+
+ import rqdatac as ricequant
+
+ from finnlp.data_processors._base import _Base
+
+
+ class Ricequant(_Base):
+     def __init__(
+         self,
+         data_source: str,
+         start_date: str,
+         end_date: str,
+         time_interval: str,
+         **kwargs,
+     ):
+         super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
+         if kwargs.get("username") is None or kwargs.get("password") is None:
+             ricequant.init()  # if the license is already set, you can init without username and password
+         else:
+             ricequant.init(
+                 kwargs["username"], kwargs["password"]
+             )  # init with username and password
+
+     def download_data(
+         self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
+     ):
+         # download data by calling the RiceQuant API
+         dataframe = ricequant.get_price(
+             ticker_list,
+             frequency=self.time_interval,
+             start_date=self.start_date,
+             end_date=self.end_date,
+         )
+         self.dataframe = dataframe
+
+         self.save_data(save_path)
+
+         print(
+             f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
+         )
+
+     # def clean_data(self, df) -> pd.DataFrame:
+     #     '''RiceQuant data is already cleaned, we only need to transform the data format here.
+     #     No need for filling NaN data.'''
+     #     df = df.copy()
+     #     # the raw df uses a multi-index (tic, time); reset it to a single index (time)
+     #     df = df.reset_index(level=[0, 1])
+     #     # rename the column order_book_id to tic
+     #     df = df.rename(columns={'order_book_id': 'tic', 'datetime': 'time'})
+     #     # keep the columns needed
+     #     df = df[['tic', 'time', 'open', 'high', 'low', 'close', 'volume']]
+     #     # check if there are NaN values
+     #     assert not df.isnull().values.any()
+     #     return df
+
+     # def add_vix(self, data):
+     #     print('VIX is NOT applicable to China A-shares')
+     #     return data
+
+     # def calculate_turbulence(self, data, time_period=252):
+     #     # can add other market assets
+     #     df = data.copy()
+     #     df_price_pivot = df.pivot(index="date", columns="tic", values="close")
+     #     # use returns to calculate turbulence
+     #     df_price_pivot = df_price_pivot.pct_change()
+     #
+     #     unique_date = df.date.unique()
+     #     # start after a fixed time period
+     #     start = time_period
+     #     turbulence_index = [0] * start
+     #     # turbulence_index = [0]
+     #     count = 0
+     #     for i in range(start, len(unique_date)):
+     #         current_price = df_price_pivot[df_price_pivot.index == unique_date[i]]
+     #         # use a one-year rolling window to calculate covariance
+     #         hist_price = df_price_pivot[
+     #             (df_price_pivot.index < unique_date[i])
+     #             & (df_price_pivot.index >= unique_date[i - time_period])
+     #         ]
+     #         # drop tickers which have more missing values than the "oldest" ticker
+     #         filtered_hist_price = hist_price.iloc[hist_price.isna().sum().min():].dropna(axis=1)
+     #
+     #         cov_temp = filtered_hist_price.cov()
+     #         current_temp = current_price[[x for x in filtered_hist_price]] - np.mean(filtered_hist_price, axis=0)
+     #         temp = current_temp.values.dot(np.linalg.pinv(cov_temp)).dot(
+     #             current_temp.values.T
+     #         )
+     #         if temp > 0:
+     #             count += 1
+     #             if count > 2:
+     #                 turbulence_temp = temp[0][0]
+     #             else:
+     #                 # avoid large outliers because the calculation has just begun
+     #                 turbulence_temp = 0
+     #         else:
+     #             turbulence_temp = 0
+     #         turbulence_index.append(turbulence_temp)
+     #
+     #     turbulence_index = pd.DataFrame(
+     #         {"date": df_price_pivot.index, "turbulence": turbulence_index}
+     #     )
+     #     return turbulence_index
+     #
+     # def add_turbulence(self, data, time_period=252):
+     #     """
+     #     add turbulence index from a precalculated dataframe
+     #     :param data: (df) pandas dataframe
+     #     :return: (df) pandas dataframe
+     #     """
+     #     df = data.copy()
+     #     turbulence_index = self.calculate_turbulence(df, time_period=time_period)
+     #     df = df.merge(turbulence_index, on="date")
+     #     df = df.sort_values(["date", "tic"]).reset_index(drop=True)
+     #     return df
+
+     # def df_to_array(self, df, tech_indicator_list, if_vix):
+     #     df = df.copy()
+     #     unique_ticker = df.tic.unique()
+     #     if_first_time = True
+     #     for tic in unique_ticker:
+     #         if if_first_time:
+     #             price_array = df[df.tic==tic][['close']].values
+     #             tech_array = df[df.tic==tic][tech_indicator_list].values
+     #             # risk_array = df[df.tic==tic]['turbulence'].values
+     #             if_first_time = False
+     #         else:
+     #             price_array = np.hstack([price_array, df[df.tic==tic][['close']].values])
+     #             tech_array = np.hstack([tech_array, df[df.tic==tic][tech_indicator_list].values])
+     #     print('Successfully transformed into array')
+     #     return price_array, tech_array, None
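
A usage sketch for the Ricequant processor; rqdatac needs a licensed account, and the credentials are placeholders. RiceQuant order_book_ids use the .XSHE/.XSHG suffixes:

    # Hypothetical usage; requires an rqdatac license.
    processor = Ricequant(
        data_source="ricequant",
        start_date="2021-01-01",
        end_date="2021-12-31",
        time_interval="1d",
        username="your_username",
        password="your_password",
    )
    processor.download_data(["000001.XSHE", "600000.XSHG"], save_path="./data/rq.csv")
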
finnlp/data_processors/tushare.py ADDED
@@ -0,0 +1,318 @@
+ import copy
+ import os
+ import time
+ import warnings
+
+ warnings.filterwarnings("ignore")
+ from typing import List
+
+ import pandas as pd
+ from tqdm import tqdm
+ from matplotlib import pyplot as plt
+
+ import stockstats
+ import talib
+ from finnlp.data_processors._base import _Base
+
+ import tushare as ts
+
+
+ class Tushare(_Base):
+     """
+     key-value pairs in kwargs
+     ----------
+     token : str
+         get it from https://waditu.com/ after registration
+     adj : str
+         whether to use the adjusted closing price; default is None.
+         If you want to use the forward adjusted closing price (前复权), please use 'qfq'.
+         If you want to use the backward adjusted closing price (后复权), please use 'hfq'.
+     """
+
+     def __init__(
+         self,
+         data_source: str,
+         start_date: str,
+         end_date: str,
+         time_interval: str,
+         **kwargs,
+     ):
+         super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
+         assert "token" in kwargs.keys(), "Please input token!"
+         self.token = kwargs["token"]
+         if "adj" in kwargs.keys():
+             self.adj = kwargs["adj"]
+             print(f"Using {self.adj} method.")
+         else:
+             self.adj = None
+
+     def get_data(self, id) -> pd.DataFrame:
+         # df1 = ts.pro_bar(ts_code=id, start_date=self.start_date, end_date='20180101')
+         # dfb = pd.concat([df, df1], ignore_index=True)
+         # print(dfb.shape)
+         return ts.pro_bar(
+             ts_code=id,
+             start_date=self.start_date,
+             end_date=self.end_date,
+             adj=self.adj,
+         )
+
+     def download_data(
+         self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
+     ):
+         """
+         `pd.DataFrame`
+             7 columns: a tick symbol, time, open, high, low, close and volume
+             for the specified stock ticker
+         """
+         assert self.time_interval == "1d", "Not supported currently"
+
+         self.ticker_list = ticker_list
+         ts.set_token(self.token)
+
+         self.dataframe = pd.DataFrame()
+         for i in tqdm(ticker_list, total=len(ticker_list)):
+             # nonstandard_id = self.transfer_standard_ticker_to_nonstandard(i)
+             # df_temp = self.get_data(nonstandard_id)
+             df_temp = self.get_data(i)
+             self.dataframe = pd.concat([self.dataframe, df_temp], ignore_index=True)
+             # print("{} ok".format(i))
+             time.sleep(0.25)
+
+         self.dataframe.columns = [
+             "tic",
+             "time",
+             "open",
+             "high",
+             "low",
+             "close",
+             "pre_close",
+             "change",
+             "pct_chg",
+             "volume",
+             "amount",
+         ]
+         self.dataframe.sort_values(by=["time", "tic"], inplace=True)
+         self.dataframe.reset_index(drop=True, inplace=True)
+
+         self.dataframe = self.dataframe[
+             ["tic", "time", "open", "high", "low", "close", "volume"]
+         ]
+         # self.dataframe.loc[:, 'tic'] = pd.DataFrame((self.dataframe['tic'].tolist()))
+         self.dataframe["time"] = pd.to_datetime(self.dataframe["time"], format="%Y%m%d")
+         self.dataframe["day"] = self.dataframe["time"].dt.dayofweek
+         self.dataframe["time"] = self.dataframe.time.apply(
+             lambda x: x.strftime("%Y-%m-%d")
+         )
+
+         self.dataframe.dropna(inplace=True)
+         self.dataframe.sort_values(by=["time", "tic"], inplace=True)
+         self.dataframe.reset_index(drop=True, inplace=True)
+
+         self.save_data(save_path)
+
+         print(
+             f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
+         )
+
+     def data_split(self, df, start, end, target_date_col="time"):
+         """
+         Split the dataset into a training or testing set using time.
+         :param data: (df) pandas dataframe, start, end
+         :return: (df) pandas dataframe
+         """
+         data = df[(df[target_date_col] >= start) & (df[target_date_col] < end)]
+         data = data.sort_values([target_date_col, "tic"], ignore_index=True)
+         data.index = data[target_date_col].factorize()[0]
+         return data
+
+     def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str:
+         # "600000.XSHG" -> "600000.SH"
+         # "000612.XSHE" -> "000612.SZ"
+         n, alpha = ticker.split(".")
+         assert alpha in ["XSHG", "XSHE"], "Wrong alpha"
+         if alpha == "XSHG":
+             nonstandard_ticker = n + ".SH"
+         elif alpha == "XSHE":
+             nonstandard_ticker = n + ".SZ"
+         return nonstandard_ticker
+
+     def save_data(self, path):
+         if ".csv" in path:
+             path = path.split("/")
+             filename = path[-1]
+             path = "/".join(path[:-1] + [""])
+         else:
+             if path[-1] == "/":
+                 filename = "dataset.csv"
+             else:
+                 filename = "/dataset.csv"
+
+         os.makedirs(path, exist_ok=True)
+         self.dataframe.to_csv(path + filename, index=False)
+
+     def load_data(self, path):
+         assert ".csv" in path  # only the csv format is supported now
+         self.dataframe = pd.read_csv(path)
+         columns = self.dataframe.columns
+         assert (
+             "tic" in columns and "time" in columns and "close" in columns
+         )  # the input file must have "tic", "time" and "close" columns
+
+
+ class ReturnPlotter:
+     """
+     An easy-to-use plotting tool to plot cumulative returns over time.
+     The baseline supports equal weighting (default) and any stocks you want to use for comparison.
+     """
+
+     def __init__(self, df_account_value, df_trade, start_date, end_date):
+         self.start = start_date
+         self.end = end_date
+         self.trade = df_trade
+         self.df_account_value = df_account_value
+
+     def get_baseline(self, ticket):
+         df = ts.get_hist_data(ticket, start=self.start, end=self.end)
+         df.loc[:, "dt"] = df.index
+         df.index = range(len(df))
+         df.sort_values(axis=0, by="dt", ascending=True, inplace=True)
+         df["time"] = pd.to_datetime(df["dt"], format="%Y-%m-%d")
+         return df
+
+     def plot(self, baseline_ticket=None):
+         """
+         Plot cumulative returns over time.
+         Use baseline_ticket to specify the stock you want to use for comparison
+         (default: equal-weighted returns).
+         """
+         baseline_label = "Equal-weight portfolio"
+         tic2label = {"399300": "CSI 300 Index", "000016": "SSE 50 Index"}
+         if baseline_ticket:
+             # use the specified ticker as the baseline
+             baseline_df = self.get_baseline(baseline_ticket)
+             baseline_date_list = baseline_df.time.dt.strftime("%Y-%m-%d").tolist()
+             df_date_list = self.df_account_value.time.tolist()
+             df_account_value = self.df_account_value[
+                 self.df_account_value.time.isin(baseline_date_list)
+             ]
+             baseline_df = baseline_df[baseline_df.time.isin(df_date_list)]
+             baseline = baseline_df.close.tolist()
+             baseline_label = tic2label.get(baseline_ticket, baseline_ticket)
+             ours = df_account_value.account_value.tolist()
+         else:
+             # equal-weight baseline
+             all_date = self.trade.time.unique().tolist()
+             baseline = []
+             for day in all_date:
+                 day_close = self.trade[self.trade["time"] == day].close.tolist()
+                 avg_close = sum(day_close) / len(day_close)
+                 baseline.append(avg_close)
+             ours = self.df_account_value.account_value.tolist()
+
+         ours = self.pct(ours)
+         baseline = self.pct(baseline)
+
+         days_per_tick = (
+             60  # you should scale this variable according to the total trading days
+         )
+         time = list(range(len(ours)))
+         datetimes = self.df_account_value.time.tolist()
+         ticks = [tick for t, tick in zip(time, datetimes) if t % days_per_tick == 0]
+         plt.title("Cumulative Returns")
+         plt.plot(time, ours, label="DDPG Agent", color="green")
+         plt.plot(time, baseline, label=baseline_label, color="grey")
+         plt.xticks([i * days_per_tick for i in range(len(ticks))], ticks, fontsize=7)
+
+         plt.xlabel("Date")
+         plt.ylabel("Cumulative Return")
+
+         plt.legend()
+         plt.savefig(f"plot_{baseline_ticket}.png")  # save before plt.show(), which clears the figure
+         plt.show()
+
+     def plot_all(self):
+         baseline_label = "Equal-weight portfolio"
+         tic2label = {"399300": "CSI 300 Index", "000016": "SSE 50 Index"}
+
+         # time lists
+         # algorithm time list
+         df_date_list = self.df_account_value.time.tolist()
+
+         # 399300 time list
+         csi300_df = self.get_baseline("399300")
+         csi300_date_list = csi300_df.time.dt.strftime("%Y-%m-%d").tolist()
+
+         # 000016 time list
+         sh50_df = self.get_baseline("000016")
+         sh50_date_list = sh50_df.time.dt.strftime("%Y-%m-%d").tolist()
+
+         # find the intersection
+         all_date = sorted(
+             list(set(df_date_list) & set(csi300_date_list) & set(sh50_date_list))
+         )
+
+         # filter data
+         csi300_df = csi300_df[csi300_df.time.isin(all_date)]
+         baseline_300 = csi300_df.close.tolist()
+         baseline_label_300 = tic2label["399300"]
+
+         sh50_df = sh50_df[sh50_df.time.isin(all_date)]
+         baseline_50 = sh50_df.close.tolist()
+         baseline_label_50 = tic2label["000016"]
+
+         # equal-weight baseline
+         baseline_equal_weight = []
+         for day in all_date:
+             day_close = self.trade[self.trade["time"] == day].close.tolist()
+             avg_close = sum(day_close) / len(day_close)
+             baseline_equal_weight.append(avg_close)
+
+         df_account_value = self.df_account_value[
+             self.df_account_value.time.isin(all_date)
+         ]
+         ours = df_account_value.account_value.tolist()
+
+         ours = self.pct(ours)
+         baseline_300 = self.pct(baseline_300)
+         baseline_50 = self.pct(baseline_50)
+         baseline_equal_weight = self.pct(baseline_equal_weight)
+
+         days_per_tick = (
+             60  # you should scale this variable according to the total trading days
+         )
+         time = list(range(len(ours)))
+         datetimes = self.df_account_value.time.tolist()
+         ticks = [tick for t, tick in zip(time, datetimes) if t % days_per_tick == 0]
+         plt.title("Cumulative Returns")
+         plt.plot(time, ours, label="DDPG Agent", color="darkorange")
+         plt.plot(
+             time,
+             baseline_equal_weight,
+             label=baseline_label,
+             color="cornflowerblue",
+         )  # equal weight
+         plt.plot(
+             time, baseline_300, label=baseline_label_300, color="lightgreen"
+         )  # 399300
+         plt.plot(time, baseline_50, label=baseline_label_50, color="silver")  # 000016
+         plt.xlabel("Date")
+         plt.ylabel("Cumulative Return")
+
+         plt.xticks([i * days_per_tick for i in range(len(ticks))], ticks, fontsize=7)
+         plt.legend()
+         plt.savefig("./plot_all.png")  # save before plt.show(), which clears the figure
+         plt.show()
+
+     def pct(self, l):
+         """Get each value relative to the first element (cumulative growth factor)."""
+         base = l[0]
+         return [x / base for x in l]
+
+     def get_return(self, df, value_col_name="account_value"):
+         df = copy.deepcopy(df)
+         df["daily_return"] = df[value_col_name].pct_change(1)
+         df["time"] = pd.to_datetime(df["time"], format="%Y-%m-%d")
+         df.set_index("time", inplace=True, drop=True)
+         df.index = df.index.tz_localize("UTC")
+         return pd.Series(df["daily_return"], index=df.index)
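
data_split above filters the half-open interval [start, end) and re-indexes via factorize so that every row of the same trading day shares one integer index, which FinRL-style environments step through. A small illustration on synthetic data (data_split does not use self, so it is called unbound here):

    # Illustration of Tushare.data_split's factorized index; the frame is synthetic.
    import pandas as pd

    df = pd.DataFrame({
        "time": ["2021-01-04", "2021-01-04", "2021-01-05", "2021-01-05"],
        "tic": ["600000.SH", "600036.SH", "600000.SH", "600036.SH"],
        "close": [10.0, 40.0, 10.5, 41.0],
    })
    split = Tushare.data_split(None, df, "2021-01-04", "2021-01-06")
    print(split.index.tolist())  # [0, 0, 1, 1] -- one integer per trading day
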
finnlp/data_processors/wrds.py ADDED
@@ -0,0 +1,330 @@
1
+ import datetime
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import pytz
7
+ import wrds
8
+
9
+ try:
10
+ import exchange_calendars as tc
11
+ except:
12
+ print(
13
+ "Cannot import exchange_calendars.",
14
+ "If you are using python>=3.7, please install it.",
15
+ )
16
+ import trading_calendars as tc
17
+
18
+ print("Use trading_calendars instead for wrds processor.")
19
+ # from basic_processor import _Base
20
+ from meta.data_processors._base import _Base
21
+
22
+ pd.options.mode.chained_assignment = None
23
+
24
+
25
+ class Wrds(_Base):
26
+ # def __init__(self,if_offline=False):
27
+ # if not if_offline:
28
+ # self.db = wrds.Connection()
29
+ def __init__(
30
+ self,
31
+ data_source: str,
32
+ start_date: str,
33
+ end_date: str,
34
+ time_interval: str,
35
+ **kwargs,
36
+ ):
37
+ super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
38
+ if "if_offline" in kwargs.keys() and not kwargs["if_offline"]:
39
+ self.db = wrds.Connection()
40
+
41
+ def download_data(
42
+ self,
43
+ ticker_list: List[str],
44
+ if_save_tempfile=False,
45
+ filter_shares=0,
46
+ save_path: str = "./data/dataset.csv",
47
+ ):
48
+ dates = self.get_trading_days(self.start_date, self.end_date)
49
+ print("Trading days: ")
50
+ print(dates)
51
+ first_time = True
52
+ empty = True
53
+ stock_set = tuple(ticker_list)
54
+ for i in dates:
55
+ x = self.data_fetch_wrds(i, stock_set, filter_shares, self.time_interval)
56
+
57
+ if not x[1]:
58
+ empty = False
59
+ dataset = x[0]
60
+ dataset = self.preprocess_to_ohlcv(
61
+ dataset, time_interval=(str(self.time_interval) + "S")
62
+ )
63
+ print("Data for date: " + i + " finished")
64
+ if first_time:
65
+ temp = dataset
66
+ first_time = False
67
+ else:
68
+ temp = pd.concat([temp, dataset])
69
+ if if_save_tempfile:
70
+ temp.to_csv("./temp.csv")
71
+ if empty:
72
+ raise ValueError("Empty Data under input parameters!")
73
+ result = temp
74
+ result = result.sort_values(by=["time", "tic"])
75
+ result = result.reset_index(drop=True)
76
+ self.dataframe = result
77
+
78
+ self.save_data(save_path)
79
+
80
+ print(
81
+ f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
82
+ )
83
+
84
+ def preprocess_to_ohlcv(self, df, time_interval="60S"):
85
+ df = df[["date", "time_m", "sym_root", "size", "price"]]
86
+ tic_list = np.unique(df["sym_root"].values)
87
+ final_df = None
88
+ first_time = True
89
+ for i in range(len(tic_list)):
90
+ tic = tic_list[i]
91
+ time_list = []
92
+ temp_df = df[df["sym_root"] == tic]
93
+ for i in range(temp_df.shape[0]):
94
+ date = temp_df["date"].iloc[i]
95
+ time_m = temp_df["time_m"].iloc[i]
96
+ time = str(date) + " " + str(time_m)
97
+ try:
98
+ time = datetime.datetime.strptime(time, "%Y-%m-%d %H:%M:%S.%f")
99
+ except:
100
+ time = datetime.datetime.strptime(time, "%Y-%m-%d %H:%M:%S")
101
+ time_list.append(time)
102
+ temp_df["time"] = time_list
103
+ temp_df = temp_df.set_index("time")
104
+ data_ohlc = temp_df["price"].resample(time_interval).ohlc()
105
+ data_v = temp_df["size"].resample(time_interval).agg({"size": "sum"})
106
+ volume = data_v["size"].values
107
+ data_ohlc["volume"] = volume
108
+ data_ohlc["tic"] = tic
109
+ if first_time:
110
+ final_df = data_ohlc.reset_index()
111
+ first_time = False
112
+ else:
113
+ final_df = final_df.append(data_ohlc.reset_index(), ignore_index=True)
114
+ return final_df
115
+
116
+ def clean_data(self):
117
+ df = self.dataframe[["time", "open", "high", "low", "close", "volume", "tic"]]
118
+ # remove 16:00 data
119
+ tic_list = np.unique(df["tic"].values)
120
+ ary = df.values
121
+ rows_1600 = []
122
+ for i in range(ary.shape[0]):
123
+ row = ary[i]
124
+ time = row[0]
125
+ if str(time)[-8:] == "16:00:00":
126
+ rows_1600.append(i)
127
+
128
+ df = df.drop(rows_1600)
129
+ df = df.sort_values(by=["tic", "time"])
130
+
131
+ # check missing rows
132
+ tic_dic = {tic: [0, 0] for tic in tic_list}
133
+ ary = df.values
134
+ for i in range(ary.shape[0]):
135
+ row = ary[i]
136
+ volume = row[5]
137
+ tic = row[6]
138
+ if volume != 0:
139
+ tic_dic[tic][0] += 1
140
+ tic_dic[tic][1] += 1
141
+ constant = np.unique(df["time"].values).shape[0]
142
+ nan_tics = [tic for tic, value in tic_dic.items() if value[1] != constant]
143
+ # fill missing rows
144
+ normal_time = np.unique(df["time"].values)
145
+
146
+ df2 = df.copy()
147
+ for tic in nan_tics:
148
+ tic_time = df[df["tic"] == tic]["time"].values
149
+ missing_time = [i for i in normal_time if i not in tic_time]
150
+ for time in missing_time:
151
+ temp_df = pd.DataFrame(
152
+ [[time, np.nan, np.nan, np.nan, np.nan, 0, tic]],
153
+ columns=[
154
+ "time",
155
+ "open",
156
+ "high",
157
+ "low",
158
+ "close",
159
+ "volume",
160
+ "tic",
161
+ ],
162
+ )
163
+ df2 = df2.append(temp_df, ignore_index=True)
164
+
165
+ # fill nan data
166
+ df = df2.sort_values(by=["tic", "time"])
167
+ for i in range(df.shape[0]):
168
+ if float(df.iloc[i]["volume"]) == 0:
169
+ previous_close = df.iloc[i - 1]["close"]
170
+ if str(previous_close) == "nan":
171
+ raise ValueError("Error nan price")
172
+ df.iloc[i, 1] = previous_close
173
+ df.iloc[i, 2] = previous_close
174
+ df.iloc[i, 3] = previous_close
175
+ df.iloc[i, 4] = previous_close
176
+ # check if nan
177
+ ary = df[["open", "high", "low", "close", "volume"]].values
178
+ assert np.isnan(np.min(ary)) == False
179
+ # final preprocess
180
+ df = df[["time", "open", "high", "low", "close", "volume", "tic"]]
181
+ df = df.reset_index(drop=True)
182
+ print("Data clean finished")
183
+ self.dataframe = df
184
+
185
+ def get_trading_days(self, start, end):
186
+ nyse = tc.get_calendar("NYSE")
187
+ df = nyse.sessions_in_range(
188
+ pd.Timestamp(start, tz=pytz.UTC), pd.Timestamp(end, tz=pytz.UTC)
189
+ )
190
+ return [str(day)[:10] for day in df]
191
+
192
+ def data_fetch_wrds(
193
+ self,
194
+ date="2021-05-01",
195
+ stock_set=("AAPL"),
196
+ filter_shares=0,
197
+ time_interval=60,
198
+ ):
199
+ # start_date, end_date should be in the same year
200
+ current_date = datetime.datetime.strptime(date, "%Y-%m-%d")
201
+ lib = "taqm_" + str(current_date.year) # taqm_2021
202
+ table = "ctm_" + current_date.strftime("%Y%m%d") # ctm_20210501
203
+
204
+ parm = {"syms": stock_set, "num_shares": filter_shares}
205
+ try:
206
+ data = self.db.raw_sql(
207
+ "select * from "
208
+ + lib
209
+ + "."
210
+ + table
211
+ + " where sym_root in %(syms)s and time_m between '9:30:00' and '16:00:00' and size > %(num_shares)s and sym_suffix is null",
212
+ params=parm,
213
+ )
214
+ if_empty = False
215
+ return data, if_empty
216
+ except:
217
+ print("Data for date: " + date + " error")
218
+ if_empty = True
219
+ return None, if_empty
220
+
221
+ # def add_technical_indicator(self, df, tech_indicator_list = [
222
+ # 'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'dx_30',
223
+ # 'close_30_sma', 'close_60_sma']):
224
+ # df = df.rename(columns={'time':'date'})
225
+ # df = df.copy()
226
+ # df = df.sort_values(by=['tic', 'date'])
227
+ # stock = Sdf.retype(df.copy())
228
+ # unique_ticker = stock.tic.unique()
229
+ # tech_indicator_list = tech_indicator_list
230
+ #
231
+ # for indicator in tech_indicator_list:
232
+ # indicator_df = pd.DataFrame()
233
+ # for i in range(len(unique_ticker)):
234
+ # # print(unique_ticker[i], i)
235
+ # temp_indicator = stock[stock.tic == unique_ticker[i]][indicator]
236
+ # temp_indicator = pd.DataFrame(temp_indicator)
237
+ # temp_indicator['tic'] = unique_ticker[i]
238
+ # # print(len(df[df.tic == unique_ticker[i]]['date'].to_list()))
239
+ # temp_indicator['date'] = df[df.tic == unique_ticker[i]]['date'].to_list()
240
+ # indicator_df = indicator_df.append(
241
+ # temp_indicator, ignore_index=True
242
+ # )
243
+ # df = df.merge(indicator_df[['tic', 'date', indicator]], on=['tic', 'date'], how='left')
244
+ # df = df.sort_values(by=['date', 'tic'])
245
+ # print('Succesfully add technical indicators')
246
+ # return df
247
+
+     # def calculate_turbulence(self, data, time_period=252):
+     #     # can add other market assets
+     #     df = data.copy()
+     #     df_price_pivot = df.pivot(index="date", columns="tic", values="close")
+     #     # use returns to calculate turbulence
+     #     df_price_pivot = df_price_pivot.pct_change()
+     #
+     #     unique_date = df.date.unique()
+     #     # start after a fixed time period
+     #     start = time_period
+     #     turbulence_index = [0] * start
+     #     count = 0
+     #     for i in range(start, len(unique_date)):
+     #         current_price = df_price_pivot[df_price_pivot.index == unique_date[i]]
+     #         # use a one-year rolling window to calculate covariance
+     #         hist_price = df_price_pivot[
+     #             (df_price_pivot.index < unique_date[i])
+     #             & (df_price_pivot.index >= unique_date[i - time_period])
+     #         ]
+     #         # drop tickers with more missing values than the "oldest" ticker
+     #         filtered_hist_price = hist_price.iloc[hist_price.isna().sum().min():].dropna(axis=1)
+     #
+     #         cov_temp = filtered_hist_price.cov()
+     #         current_temp = current_price[[x for x in filtered_hist_price]] - np.mean(filtered_hist_price, axis=0)
+     #         temp = current_temp.values.dot(np.linalg.pinv(cov_temp)).dot(
+     #             current_temp.values.T
+     #         )
+     #         if temp > 0:
+     #             count += 1
+     #             if count > 2:
+     #                 turbulence_temp = temp[0][0]
+     #             else:
+     #                 # avoid large outliers because the calculation has just begun
+     #                 turbulence_temp = 0
+     #         else:
+     #             turbulence_temp = 0
+     #         turbulence_index.append(turbulence_temp)
+     #
+     #     turbulence_index = pd.DataFrame(
+     #         {"date": df_price_pivot.index, "turbulence": turbulence_index}
+     #     )
+     #     return turbulence_index
+ 
+     # def add_turbulence(self, data, time_period=252):
+     #     """
+     #     add turbulence index from a precalculated dataframe
+     #     :param data: (df) pandas dataframe
+     #     :return: (df) pandas dataframe
+     #     """
+     #     df = data.copy()
+     #     turbulence_index = self.calculate_turbulence(df, time_period=time_period)
+     #     df = df.merge(turbulence_index, on="date")
+     #     df = df.sort_values(["date", "tic"]).reset_index(drop=True)
+     #     return df
+ 
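+     # For reference, calculate_turbulence above computes a Mahalanobis
+     # distance: with r_t the vector of current returns, mu the mean and
+     # Sigma the covariance of returns over the trailing window,
+     #     turbulence_t = (r_t - mu) @ pinv(Sigma) @ (r_t - mu).T
+     # np.linalg.pinv is used because Sigma can be singular when tickers
+     # are collinear or the window is short.
+ 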
+     # def add_vix(self, data):
+     #     vix_df = self.download_data(['vix'], self.start_date, self.end_date, self.time_interval)
+     #     cleaned_vix = self.clean_data(vix_df)
+     #     vix = cleaned_vix[['date', 'close']]
+     #
+     #     df = data.copy()
+     #     df = df.merge(vix, on="date")
+     #     df = df.sort_values(["date", "tic"]).reset_index(drop=True)
+     #
+     #     return df
+ 
+     # def df_to_array(self, df, tech_indicator_list):
+     #     unique_ticker = df.tic.unique()
+     #     print(unique_ticker)
+     #     if_first_time = True
+     #     for tic in unique_ticker:
+     #         if if_first_time:
+     #             price_array = df[df.tic == tic][['close']].values
+     #             tech_array = df[df.tic == tic][tech_indicator_list].values
+     #             # turbulence is shared across tickers, so take it once
+     #             risk_array = df[df.tic == tic]['turbulence'].values
+     #             if_first_time = False
+     #         else:
+     #             price_array = np.hstack([price_array, df[df.tic == tic][['close']].values])
+     #             tech_array = np.hstack([tech_array, df[df.tic == tic][tech_indicator_list].values])
+     #     print('Successfully transformed into array')
+     #     return price_array, tech_array, risk_array
finnlp/data_processors/yahoofinance.py ADDED
@@ -0,0 +1,190 @@
+ from typing import List
+ 
+ import numpy as np
+ import pandas as pd
+ import pytz
+ import yfinance as yf
+ 
+ try:
+     import exchange_calendars as tc
+ except ImportError:
+     print(
+         "Cannot import exchange_calendars.",
+         "If you are using python>=3.7, please install it.",
+     )
+     import trading_calendars as tc
+ 
+     print("Using trading_calendars instead for the yahoofinance processor.")
+ 
+ from finnlp.utils.config import (
+     BINANCE_BASE_URL,
+     TIME_ZONE_BERLIN,
+     TIME_ZONE_JAKARTA,
+     TIME_ZONE_PARIS,
+     TIME_ZONE_SELFDEFINED,
+     TIME_ZONE_SHANGHAI,
+     TIME_ZONE_USEASTERN,
+     USE_TIME_ZONE_SELFDEFINED,
+ )
+ from finnlp.data_processors._base import _Base, calc_time_zone
+ 
+ 
+ class Yahoofinance(_Base):
+     def __init__(
+         self,
+         data_source: str,
+         start_date: str,
+         end_date: str,
+         time_interval: str,
+         **kwargs,
+     ):
+         super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
+ 
+     def download_data(
+         self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
+     ):
+         self.time_zone = calc_time_zone(
+             ticker_list, TIME_ZONE_SELFDEFINED, USE_TIME_ZONE_SELFDEFINED
+         )
+         self.dataframe = pd.DataFrame()
+         # download each ticker separately and stack the results,
+         # tagging every row with its ticker symbol
+         for tic in ticker_list:
+             temp_df = yf.download(
+                 tic,
+                 start=self.start_date,
+                 end=self.end_date,
+                 interval=self.time_interval,
+             )
+             temp_df["tic"] = tic
+             self.dataframe = pd.concat([self.dataframe, temp_df], axis=0, join="outer")
+         self.dataframe.reset_index(inplace=True)
+         try:
+             self.dataframe.columns = [
+                 "date",
+                 "open",
+                 "high",
+                 "low",
+                 "close",
+                 "adjusted_close",
+                 "volume",
+                 "tic",
+             ]
+         except ValueError:
+             # assigning .columns raises ValueError (not NotImplementedError)
+             # when yfinance returns an unexpected number of columns
+             print("the features are not supported currently")
+         self.dataframe["day"] = self.dataframe["date"].dt.dayofweek
+         self.dataframe["date"] = self.dataframe.date.apply(
+             lambda x: x.strftime("%Y-%m-%d")
+         )
+         self.dataframe.dropna(inplace=True)
+         self.dataframe.reset_index(drop=True, inplace=True)
+         print("Shape of DataFrame: ", self.dataframe.shape)
+         self.dataframe.sort_values(by=["date", "tic"], inplace=True)
+         self.dataframe.reset_index(drop=True, inplace=True)
+ 
+         self.save_data(save_path)
+ 
+         print(
+             f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
+         )
+ 
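+     # Note: yf.download above receives self.time_interval verbatim, and
+     # yfinance expects lowercase interval strings such as "1d" or "1m",
+     # while clean_data below checks "1D"/"1Min"; callers may need to map
+     # between the two conventions.
+ 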
+     def clean_data(self):
+         df = self.dataframe.copy()
+         df = df.rename(columns={"date": "time"})
+         time_interval = self.time_interval
+         tic_list = np.unique(df.tic.values)
+         trading_days = self.get_trading_days(start=self.start_date, end=self.end_date)
+         if time_interval == "1D":
+             times = trading_days
+         elif time_interval == "1Min":
+             times = []
+             for day in trading_days:
+                 current_time = pd.Timestamp(day + " 09:30:00").tz_localize(
+                     self.time_zone
+                 )
+                 # 390 one-minute bars per regular NYSE session (09:30-16:00)
+                 for _ in range(390):
+                     times.append(current_time)
+                     current_time += pd.Timedelta(minutes=1)
+         else:
+             raise ValueError(
+                 "Data clean at given time interval is not supported for YahooFinance data."
+             )
+         new_df = pd.DataFrame()
+         for tic in tic_list:
+             print("Clean data for " + tic)
+             # build an empty frame indexed by the full expected time grid
+             tmp_df = pd.DataFrame(
+                 columns=[
+                     "open",
+                     "high",
+                     "low",
+                     "close",
+                     "adjusted_close",
+                     "volume",
+                 ],
+                 index=times,
+             )
+             # get data for the current ticker
+             tic_df = df[df.tic == tic]
+             # fill the empty DataFrame using the original data
+             for i in range(tic_df.shape[0]):
+                 tmp_df.loc[tic_df.iloc[i]["time"]] = tic_df.iloc[i][
+                     [
+                         "open",
+                         "high",
+                         "low",
+                         "close",
+                         "adjusted_close",
+                         "volume",
+                     ]
+                 ]
+ 
+             # if close on the start date is NaN, fill data with the first
+             # valid close and set volume to 0.
+             if str(tmp_df.iloc[0]["close"]) == "nan":
+                 print("NaN data on start date, fill using first valid data.")
+                 for i in range(tmp_df.shape[0]):
+                     if str(tmp_df.iloc[i]["close"]) != "nan":
+                         first_valid_close = tmp_df.iloc[i]["close"]
+                         first_valid_adjclose = tmp_df.iloc[i]["adjusted_close"]
+                         break  # stop at the first valid row
+ 
+                 tmp_df.iloc[0] = [
+                     first_valid_close,
+                     first_valid_close,
+                     first_valid_close,
+                     first_valid_close,
+                     first_valid_adjclose,
+                     0.0,
+                 ]
+ 
+             # fill NaN data with the previous close and set volume to 0.
+             for i in range(tmp_df.shape[0]):
+                 if str(tmp_df.iloc[i]["close"]) == "nan":
+                     previous_close = tmp_df.iloc[i - 1]["close"]
+                     previous_adjusted_close = tmp_df.iloc[i - 1]["adjusted_close"]
+                     if str(previous_close) == "nan":
+                         raise ValueError("No previous close available to forward-fill.")
+                     tmp_df.iloc[i] = [
+                         previous_close,
+                         previous_close,
+                         previous_close,
+                         previous_close,
+                         previous_adjusted_close,
+                         0.0,
+                     ]
+ 
+             # merge single-ticker data into the new DataFrame
+             # (DataFrame.append was removed in pandas 2.0; use pd.concat)
+             tmp_df = tmp_df.astype(float)
+             tmp_df["tic"] = tic
+             new_df = pd.concat([new_df, tmp_df])
+ 
+             print("Data clean for " + tic + " is finished.")
+ 
+         # reset index and rename columns
+         new_df = new_df.reset_index()
+         new_df = new_df.rename(columns={"index": "time"})
+         print("Data clean all finished!")
+         self.dataframe = new_df
+ 
+     def get_trading_days(self, start, end):
+         nyse = tc.get_calendar("NYSE")
+         df = nyse.sessions_in_range(pd.Timestamp(start), pd.Timestamp(end))
+         return [str(day)[:10] for day in df]
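
To close, a minimal end-to-end sketch of the new processor (argument values are illustrative; `save_data` is presumably provided by `_Base`, since `download_data` calls it):

```python
from finnlp.data_processors.yahoofinance import Yahoofinance

# Illustrative arguments; "yahoofinance" and "1D" mirror the strings used above.
processor = Yahoofinance(
    data_source="yahoofinance",
    start_date="2021-01-04",
    end_date="2021-06-30",
    time_interval="1D",  # see the interval-string note in download_data
)
processor.download_data(["AAPL", "MSFT"])  # also saves ./data/dataset.csv
processor.clean_data()                     # aligns every ticker to NYSE sessions
print(processor.dataframe.head())
```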