Commit de6e775
Parent(s): c31337d
Upload 19 files
- finnlp/data_processors/__pycache__/__init__.cpython-310.pyc +0 -0
- finnlp/data_processors/__pycache__/_base.cpython-310.pyc +0 -0
- finnlp/data_processors/__pycache__/yahoofinance.cpython-310.pyc +0 -0
- finnlp/data_processors/_base.py +637 -0
- finnlp/data_processors/akshare.py +148 -0
- finnlp/data_processors/alpaca.py +441 -0
- finnlp/data_processors/alphavantage.py +92 -0
- finnlp/data_processors/baostock.py +114 -0
- finnlp/data_processors/binance.py +434 -0
- finnlp/data_processors/ccxt.py +143 -0
- finnlp/data_processors/fx.py +102 -0
- finnlp/data_processors/iexcloud.py +143 -0
- finnlp/data_processors/joinquant.py +68 -0
- finnlp/data_processors/quandl.py +85 -0
- finnlp/data_processors/quantconnect.py +70 -0
- finnlp/data_processors/ricequant.py +131 -0
- finnlp/data_processors/tushare.py +318 -0
- finnlp/data_processors/wrds.py +330 -0
- finnlp/data_processors/yahoofinance.py +190 -0
finnlp/data_processors/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (172 Bytes).
finnlp/data_processors/__pycache__/_base.cpython-310.pyc
ADDED
Binary file (12.6 kB).
finnlp/data_processors/__pycache__/yahoofinance.cpython-310.pyc
ADDED
Binary file (4.66 kB).
finnlp/data_processors/_base.py
ADDED
@@ -0,0 +1,637 @@
import copy
import os
import urllib
import zipfile
from datetime import *
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
import stockstats
import talib

from finnlp.utils.config import BINANCE_BASE_URL
from finnlp.utils.config import TIME_ZONE_BERLIN
from finnlp.utils.config import TIME_ZONE_JAKARTA
from finnlp.utils.config import TIME_ZONE_PARIS
from finnlp.utils.config import TIME_ZONE_SELFDEFINED
from finnlp.utils.config import TIME_ZONE_SHANGHAI
from finnlp.utils.config import TIME_ZONE_USEASTERN
from finnlp.utils.config import USE_TIME_ZONE_SELFDEFINED
from finnlp.utils.config_tickers import CAC_40_TICKER
from finnlp.utils.config_tickers import CSI_300_TICKER
from finnlp.utils.config_tickers import DAX_30_TICKER
from finnlp.utils.config_tickers import DOW_30_TICKER
from finnlp.utils.config_tickers import HSI_50_TICKER
from finnlp.utils.config_tickers import LQ45_TICKER
from finnlp.utils.config_tickers import MDAX_50_TICKER
from finnlp.utils.config_tickers import NAS_100_TICKER
from finnlp.utils.config_tickers import SDAX_50_TICKER
from finnlp.utils.config_tickers import SP_500_TICKER
from finnlp.utils.config_tickers import SSE_50_TICKER
from finnlp.utils.config_tickers import TECDAX_TICKER


class _Base:
    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        self.data_source: str = data_source
        self.start_date: str = start_date
        self.end_date: str = end_date
        self.time_interval: str = time_interval  # standard time_interval
        # transferred_time_interval will be supported in the future.
        # self.nonstandard_time_interval: str = self.calc_nonstandard_time_interval()  # transferred time_interval of this processor
        self.time_zone: str = ""
        self.dataframe: pd.DataFrame = pd.DataFrame()
        self.dictnumpy: dict = (
            {}
        )  # e.g., self.dictnumpy["open"] = np.array([1, 2, 3]), self.dictnumpy["close"] = np.array([1, 2, 3])

    def download_data(self, ticker_list: List[str]):
        pass

    def clean_data(self):
        if "date" in self.dataframe.columns.values.tolist():
            self.dataframe.rename(columns={"date": "time"}, inplace=True)
        if "datetime" in self.dataframe.columns.values.tolist():
            self.dataframe.rename(columns={"datetime": "time"}, inplace=True)
        if self.data_source == "ccxt":
            self.dataframe.rename(columns={"index": "time"}, inplace=True)

        if self.data_source == "ricequant":
            # RiceQuant data is already cleaned, we only need to transform
            # the data format here. No need for filling NaN data.
            self.dataframe.rename(columns={"order_book_id": "tic"}, inplace=True)
            # raw df uses a multi-index (tic, time); reset it to a single index (time)
            self.dataframe.reset_index(level=[0, 1], inplace=True)
            # check that there are no NaN values
            assert not self.dataframe.isnull().values.any()
        elif self.data_source == "baostock":
            self.dataframe.rename(columns={"code": "tic"}, inplace=True)

        self.dataframe.dropna(inplace=True)
        # adjusted_close: adjusted close price
        if "adjusted_close" not in self.dataframe.columns.values.tolist():
            self.dataframe["adjusted_close"] = self.dataframe["close"]
        self.dataframe.sort_values(by=["time", "tic"], inplace=True)
        self.dataframe = self.dataframe[
            ["tic", "time", "open", "high", "low", "close", "adjusted_close", "volume"]
        ]

    def fillna(self):
        df = self.dataframe

        dfcode = pd.DataFrame(columns=["tic"])
        dfdate = pd.DataFrame(columns=["time"])

        dfcode.tic = df.tic.unique()
        dfdate.time = df.time.unique()
        dfdate.sort_values(by="time", ascending=False, ignore_index=True, inplace=True)

        # old pandas versions may not support pd.merge(how="cross")
        try:
            df1 = pd.merge(dfcode, dfdate, how="cross")
        except Exception:
            print("Please wait for a few seconds...")
            df1 = pd.DataFrame(columns=["tic", "time"])
            for i in range(dfcode.shape[0]):
                for j in range(dfdate.shape[0]):
                    df1 = df1.append(
                        pd.DataFrame(
                            data={
                                "tic": dfcode.iat[i, 0],
                                "time": dfdate.iat[j, 0],
                            },
                            index=[(i + 1) * (j + 1) - 1],
                        )
                    )

        df = pd.merge(df1, df, how="left", on=["tic", "time"])

        # back fill missing data, then forward fill
        df_new = pd.DataFrame(columns=df.columns)
        for i in df.tic.unique():
            df_tmp = df[df.tic == i].fillna(method="bfill").fillna(method="ffill")
            df_new = pd.concat([df_new, df_tmp], ignore_index=True)

        df_new = df_new.fillna(0)

        # reshape dataframe
        df_new = df_new.sort_values(by=["time", "tic"]).reset_index(drop=True)

        print("Shape of DataFrame: ", df_new.shape)

        self.dataframe = df_new

    def get_trading_days(self, start: str, end: str) -> List[str]:
        if self.data_source in [
            "binance",
            "ccxt",
            "quantconnect",
            "ricequant",
            "tushare",
        ]:
            print(
                f"Calculate get_trading_days not supported for {self.data_source} yet."
            )
            return None

    # select_stockstats_talib: 0 (stockstats, default), or 1 (use talib). Users can choose the method.
    # drop_na_timesteps: 0 (keep timesteps that contain NaN), or 1 (drop timesteps that contain NaN, default). Users can choose the method.
    def add_technical_indicator(
        self,
        tech_indicator_list: List[str],
        select_stockstats_talib: int = 0,
        drop_na_timesteps: int = 1,
    ):
        """
        calculate technical indicators
        use the stockstats/talib package to add technical indicators
        :param data: (df) pandas dataframe
        :return: (df) pandas dataframe
        """
        if "date" in self.dataframe.columns.values.tolist():
            self.dataframe.rename(columns={"date": "time"}, inplace=True)

        if self.data_source == "ccxt":
            self.dataframe.rename(columns={"index": "time"}, inplace=True)

        self.dataframe.reset_index(drop=False, inplace=True)
        if "level_1" in self.dataframe.columns:
            self.dataframe.drop(columns=["level_1"], inplace=True)
        if "level_0" in self.dataframe.columns and "tic" not in self.dataframe.columns:
            self.dataframe.rename(columns={"level_0": "tic"}, inplace=True)
        assert select_stockstats_talib in {0, 1}
        print("tech_indicator_list: ", tech_indicator_list)
        if select_stockstats_talib == 0:  # use stockstats
            stock = stockstats.StockDataFrame.retype(self.dataframe)
            unique_ticker = stock.tic.unique()
            for indicator in tech_indicator_list:
                print("indicator: ", indicator)
                indicator_df = pd.DataFrame()
                for i in range(len(unique_ticker)):
                    try:
                        temp_indicator = stock[stock.tic == unique_ticker[i]][indicator]
                        temp_indicator = pd.DataFrame(temp_indicator)
                        temp_indicator["tic"] = unique_ticker[i]
                        temp_indicator["time"] = self.dataframe[
                            self.dataframe.tic == unique_ticker[i]
                        ]["time"].to_list()
                        indicator_df = pd.concat(
                            [indicator_df, temp_indicator],
                            axis=0,
                            join="outer",
                            ignore_index=True,
                        )
                    except Exception as e:
                        print(e)
                if not indicator_df.empty:
                    self.dataframe = self.dataframe.merge(
                        indicator_df[["tic", "time", indicator]],
                        on=["tic", "time"],
                        how="left",
                    )
        else:  # use talib
            final_df = pd.DataFrame()
            for i in self.dataframe.tic.unique():
                tic_df = self.dataframe[self.dataframe.tic == i]
                # assign new columns (not .loc[row] labels) with the talib outputs
                (
                    tic_df["macd"],
                    tic_df["macd_signal"],
                    tic_df["macd_hist"],
                ) = talib.MACD(
                    tic_df["close"],
                    fastperiod=12,
                    slowperiod=26,
                    signalperiod=9,
                )
                tic_df["rsi"] = talib.RSI(tic_df["close"], timeperiod=14)
                tic_df["cci"] = talib.CCI(
                    tic_df["high"],
                    tic_df["low"],
                    tic_df["close"],
                    timeperiod=14,
                )
                tic_df["dx"] = talib.DX(
                    tic_df["high"],
                    tic_df["low"],
                    tic_df["close"],
                    timeperiod=14,
                )
                final_df = pd.concat([final_df, tic_df], axis=0, join="outer")
            self.dataframe = final_df

        self.dataframe.sort_values(by=["time", "tic"], inplace=True)
        if drop_na_timesteps:
            time_to_drop = self.dataframe[
                self.dataframe.isna().any(axis=1)
            ].time.unique()
            self.dataframe = self.dataframe[~self.dataframe.time.isin(time_to_drop)]
        print("Successfully added technical indicators")

    def add_turbulence(self):
        """
        add turbulence index from a precalculated dataframe
        :param data: (df) pandas dataframe
        :return: (df) pandas dataframe
        """
        if self.data_source in [
            "binance",
            "ccxt",
            "iexcloud",
            "joinquant",
            "quantconnect",
        ]:
            print(
                f"Turbulence not supported for {self.data_source} yet. Return original DataFrame."
            )
        if self.data_source in [
            "alpaca",
            "ricequant",
            "tushare",
            "wrds",
            "yahoofinance",
        ]:
            turbulence_index = self.calculate_turbulence()
            self.dataframe = self.dataframe.merge(turbulence_index, on="time")
            self.dataframe.sort_values(["time", "tic"], inplace=True)
            self.dataframe.reset_index(drop=True, inplace=True)

    def calculate_turbulence(self, time_period: int = 252) -> pd.DataFrame:
        """calculate turbulence index based on dow 30"""
        # can add other market assets
        df_price_pivot = self.dataframe.pivot(
            index="time", columns="tic", values="close"
        )
        # use returns to calculate turbulence
        df_price_pivot = df_price_pivot.pct_change()

        unique_date = self.dataframe["time"].unique()
        # start after a year
        start = time_period
        turbulence_index = [0] * start
        count = 0
        for i in range(start, len(unique_date)):
            current_price = df_price_pivot[df_price_pivot.index == unique_date[i]]
            # use a one-year rolling window to calculate covariance
            hist_price = df_price_pivot[
                (df_price_pivot.index < unique_date[i])
                & (df_price_pivot.index >= unique_date[i - time_period])
            ]
            # drop tickers which have more missing values than the "oldest" ticker
            filtered_hist_price = hist_price.iloc[
                hist_price.isna().sum().min() :
            ].dropna(axis=1)

            cov_temp = filtered_hist_price.cov()
            current_temp = current_price[list(filtered_hist_price)] - np.mean(
                filtered_hist_price, axis=0
            )

            temp = current_temp.values.dot(np.linalg.pinv(cov_temp)).dot(
                current_temp.values.T
            )
            if temp > 0:
                count += 1
                # avoid large outliers while the calculation is just beginning: else turbulence_temp = 0
                turbulence_temp = temp[0][0] if count > 2 else 0
            else:
                turbulence_temp = 0
            turbulence_index.append(turbulence_temp)

        turbulence_index = pd.DataFrame(
            {"time": df_price_pivot.index, "turbulence": turbulence_index}
        )
        return turbulence_index

    def add_vix(self):
        """
        add vix from processors
        :param data: (df) pandas dataframe
        :return: (df) pandas dataframe
        """
        if self.data_source in [
            "binance",
            "ccxt",
            "iexcloud",
            "joinquant",
            "quantconnect",
            "ricequant",
            "tushare",
        ]:
            print(
                f"VIX is not applicable for {self.data_source}. Return original DataFrame"
            )
            return None
        elif self.data_source == "yahoofinance":
            ticker = "^VIX"
        elif self.data_source == "alpaca":
            ticker = "VIXY"
        elif self.data_source == "wrds":
            ticker = "vix"
        else:
            # no VIX proxy defined for this data source
            return None
        df = self.dataframe.copy()
        # self.download_data(self.start_date, self.end_date, self.time_interval)
        self.download_data([ticker], save_path="./data/vix.csv")
        self.clean_data()
        cleaned_vix = self.dataframe
        # .rename(columns={ticker: "vix"})
        vix = cleaned_vix[["time", "close"]]
        cleaned_vix = vix.rename(columns={"close": "vix"})

        df = df.merge(cleaned_vix, on="time")
        df = df.sort_values(["time", "tic"]).reset_index(drop=True)
        self.dataframe = df

    def df_to_array(self, tech_indicator_list: List[str], if_vix: bool):
        unique_ticker = self.dataframe.tic.unique()
        price_array = np.column_stack(
            [self.dataframe[self.dataframe.tic == tic].close for tic in unique_ticker]
        )
        common_tech_indicator_list = [
            i
            for i in tech_indicator_list
            if i in self.dataframe.columns.values.tolist()
        ]
        tech_array = np.hstack(
            [
                self.dataframe.loc[
                    (self.dataframe.tic == tic), common_tech_indicator_list
                ]
                for tic in unique_ticker
            ]
        )
        if if_vix:
            risk_array = np.column_stack(
                [self.dataframe[self.dataframe.tic == tic].vix for tic in unique_ticker]
            )
        else:
            risk_array = (
                np.column_stack(
                    [
                        self.dataframe[self.dataframe.tic == tic].turbulence
                        for tic in unique_ticker
                    ]
                )
                if "turbulence" in self.dataframe.columns
                else None
            )
        print("Successfully transformed into array")
        return price_array, tech_array, risk_array

    # standard time_interval units: s: second, m: minute, h: hour, d: day, w: week, M: month, q: quarter, y: year
    # output: the nonstandard time_interval of this processor
    def calc_nonstandard_time_interval(self) -> str:
        if self.data_source == "alpaca":
            pass
        elif self.data_source == "baostock":
            # nonstandard_time_interval: default is d (daily k-line); d=daily, w=weekly,
            # m=monthly, 5/15/30/60=minute bars, case-insensitive. Indices have no
            # minute-level data; weekly bars are available only after the last trading
            # day of the week, monthly bars after the last trading day of the month.
            time_intervals = ["5m", "15m", "30m", "60m", "1d", "1w", "1M"]
            assert self.time_interval in time_intervals, (
                "This time interval is not supported. Supported time intervals: "
                + ",".join(time_intervals)
            )
            if (
                "d" in self.time_interval
                or "w" in self.time_interval
                or "M" in self.time_interval
            ):
                return self.time_interval[-1:].lower()
            elif "m" in self.time_interval:
                return self.time_interval[:-1]
        elif self.data_source == "binance":
            # nonstandard_time_interval: 1m,3m,5m,15m,30m,1h,2h,4h,6h,8h,12h,1d,3d,1w,1M
            time_intervals = [
                "1m", "3m", "5m", "15m", "30m", "1h", "2h", "4h",
                "6h", "8h", "12h", "1d", "3d", "1w", "1M",
            ]
            assert self.time_interval in time_intervals, (
                "This time interval is not supported. Supported time intervals: "
                + ",".join(time_intervals)
            )
            return self.time_interval
        elif self.data_source == "ccxt":
            pass
        elif self.data_source == "iexcloud":
            time_intervals = ["1d"]
            assert self.time_interval in time_intervals, (
                "This time interval is not supported. Supported time intervals: "
                + ",".join(time_intervals)
            )
            return self.time_interval.upper()
        elif self.data_source == "joinquant":
            # '1m', '5m', '15m', '30m', '60m', '120m', '1d', '1w', '1M'
            time_intervals = [
                "1m", "5m", "15m", "30m", "60m", "120m", "1d", "1w", "1M",
            ]
            assert self.time_interval in time_intervals, (
                "This time interval is not supported. Supported time intervals: "
                + ",".join(time_intervals)
            )
            return self.time_interval
        elif self.data_source == "quantconnect":
            pass
        elif self.data_source == "ricequant":
            # nonstandard_time_interval: 'd' - day, 'w' - week, 'm' - month, 'q' - quarter, 'y' - year
            time_intervals = ["d", "w", "M", "q", "y"]
            assert self.time_interval[-1] in time_intervals, (
                "This time interval is not supported. Supported time intervals: "
                + ",".join(time_intervals)
            )
            if "M" in self.time_interval:
                return self.time_interval.lower()
            else:
                return self.time_interval
        elif self.data_source == "tushare":
            # Minute bars (1, 5, 15, 30, 60 min) are not supported currently.
            # time_intervals = ["1m", "5m", "15m", "30m", "60m", "1d"]
            time_intervals = ["1d"]
            assert self.time_interval in time_intervals, (
                "This time interval is not supported. Supported time intervals: "
                + ",".join(time_intervals)
            )
            return self.time_interval
        elif self.data_source == "wrds":
            pass
        elif self.data_source == "yahoofinance":
            # nonstandard_time_interval: ["1m", "2m", "5m", "15m", "30m", "60m", "90m", "1h", "1d", "5d", "1wk", "1mo", "3mo"]
            time_intervals = [
                "1m", "2m", "5m", "15m", "30m", "60m", "90m",
                "1h", "1d", "5d", "1w", "1M", "3M",
            ]
            assert self.time_interval in time_intervals, (
                "This time interval is not supported. Supported time intervals: "
                + ",".join(time_intervals)
            )
            if "w" in self.time_interval:
                return self.time_interval + "k"
            elif "M" in self.time_interval:
                return self.time_interval[:-1] + "mo"
            else:
                return self.time_interval
        else:
            raise ValueError(
                f"Not support transfer_standard_time_interval for {self.data_source}"
            )

    # "600000.XSHG" -> "sh.600000"
    # "000612.XSHE" -> "sz.000612"
    def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str:
        return ticker

    def save_data(self, path):
        if ".csv" in path:
            path = path.split("/")
            filename = path[-1]
            path = "/".join(path[:-1] + [""])
        else:
            if path[-1] == "/":
                filename = "dataset.csv"
            else:
                filename = "/dataset.csv"

        os.makedirs(path, exist_ok=True)
        self.dataframe.to_csv(path + filename, index=False)

    def load_data(self, path):
        assert ".csv" in path  # only csv format is supported now
        self.dataframe = pd.read_csv(path)
        columns = self.dataframe.columns
        print(f"{path} loaded")
        # # check loaded file
        # assert "date" in columns or "time" in columns
        # assert "close" in columns


def calc_time_zone(
    ticker_list: List[str],
    time_zone_selfdefined: str,
    use_time_zone_selfdefined: int,
) -> str:
    assert isinstance(ticker_list, list)
    ticker_list = ticker_list[0]
    if use_time_zone_selfdefined == 1:
        time_zone = time_zone_selfdefined
    elif ticker_list in HSI_50_TICKER + SSE_50_TICKER + CSI_300_TICKER:
        time_zone = TIME_ZONE_SHANGHAI
    elif ticker_list in DOW_30_TICKER + NAS_100_TICKER + SP_500_TICKER:
        time_zone = TIME_ZONE_USEASTERN
    elif ticker_list in CAC_40_TICKER:
        time_zone = TIME_ZONE_PARIS
    elif ticker_list in DAX_30_TICKER + TECDAX_TICKER + MDAX_50_TICKER + SDAX_50_TICKER:
        time_zone = TIME_ZONE_BERLIN
    elif ticker_list in LQ45_TICKER:
        time_zone = TIME_ZONE_JAKARTA
    else:
        # hack needed to have this working with vix indicator
        # fix: unable to set time_zone_selfdefined from top-level dataprocessor class
        time_zone = TIME_ZONE_USEASTERN
        # raise ValueError("Time zone is wrong.")
    return time_zone


def check_date(d: str) -> bool:
    assert (
        len(d) == 10
    ), "Please check the length of date and use the correct date like 2020-01-01."
    indices = [0, 1, 2, 3, 5, 6, 8, 9]
    correct = True
    for i in indices:
        if not d[i].isdigit():
            correct = False
            break
    if not correct:
        raise ValueError("Please use the correct date like 2020-01-01.")
    return correct
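
Every concrete processor in this commit drives _Base the same way: fill self.dataframe in download_data, then rely on the shared clean_data / add_technical_indicator / df_to_array steps. A minimal sketch of that flow follows, assuming the finnlp package from this commit is importable; the DummyProcessor class and its synthetic OHLCV data are hypothetical, used only to exercise the pipeline without a live data source.

import numpy as np
import pandas as pd

from finnlp.data_processors._base import _Base


class DummyProcessor(_Base):
    def download_data(self, ticker_list, save_path="./data/dataset.csv"):
        # build a tiny synthetic OHLCV frame instead of hitting a real API
        dates = pd.date_range("2021-01-01", periods=30).strftime("%Y-%m-%d")
        rows = []
        for tic in ticker_list:
            close = 100 + np.cumsum(np.random.randn(len(dates)))
            for t, c in zip(dates, close):
                rows.append(
                    {"tic": tic, "time": t, "open": c, "high": c + 1,
                     "low": c - 1, "close": c, "volume": 1e4}
                )
        self.dataframe = pd.DataFrame(rows)


p = DummyProcessor("dummy", "2021-01-01", "2021-02-01", "1d")
p.download_data(["AAA", "BBB"])
p.clean_data()                                 # normalizes to the standard column schema
p.add_technical_indicator(["macd", "rsi_30"])  # stockstats path by default
price, tech, risk = p.df_to_array(["macd", "rsi_30"], if_vix=False)
print(price.shape, tech.shape)                 # one column per ticker / per (ticker, indicator)
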
finnlp/data_processors/akshare.py
ADDED
@@ -0,0 +1,148 @@
import copy
import os
import time
import warnings

warnings.filterwarnings("ignore")
from typing import List

import pandas as pd
from tqdm import tqdm

import stockstats
import talib

import akshare as ak  # pip install akshare

from finnlp.data_processors._base import _Base


class Akshare(_Base):
    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        start_date = self.transfer_date(start_date)
        end_date = self.transfer_date(end_date)

        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)

        if "adj" in kwargs.keys():
            self.adj = kwargs["adj"]
            print(f"Using {self.adj} method.")
        else:
            self.adj = ""

        if "period" in kwargs.keys():
            self.period = kwargs["period"]
        else:
            self.period = "daily"

    def get_data(self, id) -> pd.DataFrame:
        return ak.stock_zh_a_hist(
            symbol=id,
            period=self.time_interval,
            start_date=self.start_date,
            end_date=self.end_date,
            adjust=self.adj,
        )

    def download_data(
        self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
    ):
        """
        Builds a `pd.DataFrame` with 7 columns for each specified stock ticker:
        tic, time, open, high, low, close and volume.
        """
        assert self.time_interval in [
            "daily",
            "weekly",
            "monthly",
        ], "Not supported currently"

        self.ticker_list = ticker_list

        self.dataframe = pd.DataFrame()
        for i in tqdm(ticker_list, total=len(ticker_list)):
            nonstandard_id = self.transfer_standard_ticker_to_nonstandard(i)
            df_temp = self.get_data(nonstandard_id)
            df_temp["tic"] = i
            self.dataframe = pd.concat([self.dataframe, df_temp])
            # throttle requests to the akshare endpoint
            time.sleep(0.25)

        self.dataframe.columns = [
            "time",
            "open",
            "close",
            "high",
            "low",
            "volume",
            "amount",
            "amplitude",
            "pct_chg",
            "change",
            "turnover",
            "tic",
        ]

        self.dataframe.sort_values(by=["time", "tic"], inplace=True)
        self.dataframe.reset_index(drop=True, inplace=True)

        self.dataframe = self.dataframe[
            ["tic", "time", "open", "high", "low", "close", "volume"]
        ]
        self.dataframe["time"] = pd.to_datetime(
            self.dataframe["time"], format="%Y-%m-%d"
        )
        self.dataframe["day"] = self.dataframe["time"].dt.dayofweek
        self.dataframe["time"] = self.dataframe.time.apply(
            lambda x: x.strftime("%Y-%m-%d")
        )

        self.dataframe.dropna(inplace=True)
        self.dataframe.sort_values(by=["time", "tic"], inplace=True)
        self.dataframe.reset_index(drop=True, inplace=True)

        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

    def data_split(self, df, start, end, target_date_col="time"):
        """
        split the dataset into training or testing using time
        :param data: (df) pandas dataframe, start, end
        :return: (df) pandas dataframe
        """
        data = df[(df[target_date_col] >= start) & (df[target_date_col] < end)]
        data = data.sort_values([target_date_col, "tic"], ignore_index=True)
        data.index = data[target_date_col].factorize()[0]
        return data

    def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str:
        # "600000.XSHG" -> "600000"
        # "000612.XSHE" -> "000612"
        # "600000.SH" -> "600000"
        # "000612.SZ" -> "000612"
        if "." in ticker:
            n, alpha = ticker.split(".")
            # assert alpha in ["XSHG", "XSHE"], "Wrong alpha"
            return n
        return ticker

    def transfer_date(self, time: str) -> str:
        if "-" in time:
            time = "".join(time.split("-"))
        elif "." in time:
            time = "".join(time.split("."))
        elif "/" in time:
            time = "".join(time.split("/"))
        return time
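
A short usage sketch for the Akshare processor above. The tickers and dates are illustrative; running it requires `pip install akshare` and network access. Note that transfer_date() strips the separators from the constructor dates, while data_split() compares against the re-formatted %Y-%m-%d strings produced by download_data.

from finnlp.data_processors.akshare import Akshare

processor = Akshare(
    data_source="akshare",
    start_date="2021-01-01",    # transfer_date() turns this into "20210101"
    end_date="2021-06-30",
    time_interval="daily",      # only daily/weekly/monthly are accepted
    adj="qfq",                  # forward-adjusted prices; omit for raw prices
)
processor.download_data(["600000.XSHG", "000612.XSHE"])  # saved to ./data/dataset.csv
# split on the "%Y-%m-%d" strings that download_data writes into the time column
train = processor.data_split(processor.dataframe, "2021-01-01", "2021-04-01")
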
finnlp/data_processors/alpaca.py
ADDED
@@ -0,0 +1,441 @@
from typing import List

import alpaca_trade_api as tradeapi
import numpy as np
import pandas as pd
import pytz

try:
    import exchange_calendars as tc
except ImportError:
    print(
        "Cannot import exchange_calendars.",
        "If you are using python>=3.7, please install it.",
    )
    import trading_calendars as tc

    print("Use trading_calendars instead for alpaca processor.")

from finnlp.data_processors._base import _Base
from finnlp.data_processors._base import calc_time_zone

from finnlp.utils.config import (
    TIME_ZONE_SHANGHAI,
    TIME_ZONE_USEASTERN,
    TIME_ZONE_PARIS,
    TIME_ZONE_BERLIN,
    TIME_ZONE_JAKARTA,
    TIME_ZONE_SELFDEFINED,
    USE_TIME_ZONE_SELFDEFINED,
    BINANCE_BASE_URL,
)


class Alpaca(_Base):
    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
        if kwargs["API"] is None:
            try:
                self.api = tradeapi.REST(
                    kwargs["API_KEY"],
                    kwargs["API_SECRET"],
                    kwargs["API_BASE_URL"],
                    "v2",
                )
            except BaseException:
                raise ValueError("Wrong Account Info!")
        else:
            self.api = kwargs["API"]

    def download_data(
        self,
        ticker_list,
        start_date,
        end_date,
        time_interval,
        save_path: str = "./data/dataset.csv",
    ) -> pd.DataFrame:
        self.time_zone = calc_time_zone(
            ticker_list, TIME_ZONE_SELFDEFINED, USE_TIME_ZONE_SELFDEFINED
        )
        start_date = pd.Timestamp(self.start_date, tz=self.time_zone)
        end_date = pd.Timestamp(self.end_date, tz=self.time_zone) + pd.Timedelta(
            days=1
        )
        self.time_interval = time_interval

        date = start_date
        data_df = pd.DataFrame()
        while date != end_date:
            start_time = (date + pd.Timedelta("09:30:00")).isoformat()
            end_time = (date + pd.Timedelta("15:59:00")).isoformat()
            for tic in ticker_list:
                barset = self.api.get_bars(
                    tic,
                    time_interval,
                    start=start_time,
                    end=end_time,
                    limit=500,
                ).df
                barset["tic"] = tic
                barset = barset.reset_index()
                data_df = pd.concat([data_df, barset], ignore_index=True)
            print("Data before " + end_time + " is successfully fetched")
            # print(data_df.head())
            date = date + pd.Timedelta(days=1)
            # compensate for daylight-saving shifts so the loop stays on midnight
            if date.isoformat()[-14:-6] == "01:00:00":
                date = date - pd.Timedelta("01:00:00")
            elif date.isoformat()[-14:-6] == "23:00:00":
                date = date + pd.Timedelta("01:00:00")
            if date.isoformat()[-14:-6] != "00:00:00":
                raise ValueError("Timezone Error")

        data_df["time"] = data_df["timestamp"].apply(
            lambda x: x.strftime("%Y-%m-%d %H:%M:%S")
        )
        self.dataframe = data_df

        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

    def clean_data(self):
        df = self.dataframe.copy()
        tic_list = np.unique(df.tic.values)

        trading_days = self.get_trading_days(start=self.start_date, end=self.end_date)
        # produce the full time index: 390 one-minute bars per trading day
        times = []
        for day in trading_days:
            current_time = pd.Timestamp(day + " 09:30:00").tz_localize(self.time_zone)
            for _ in range(390):
                times.append(current_time)
                current_time += pd.Timedelta(minutes=1)
        # create a new dataframe with the full time series
        new_df = pd.DataFrame()
        for tic in tic_list:
            tmp_df = pd.DataFrame(
                columns=["open", "high", "low", "close", "volume"], index=times
            )
            tic_df = df[df.tic == tic]
            for i in range(tic_df.shape[0]):
                tmp_df.loc[tic_df.iloc[i]["time"]] = tic_df.iloc[i][
                    ["open", "high", "low", "close", "volume"]
                ]

            # if the close price of the first row is NaN
            if str(tmp_df.iloc[0]["close"]) == "nan":
                print(
                    "The price of the first row for ticker ",
                    tic,
                    " is NaN. ",
                    "It will be filled with the first valid price.",
                )
                for i in range(tmp_df.shape[0]):
                    if str(tmp_df.iloc[i]["close"]) != "nan":
                        first_valid_price = tmp_df.iloc[i]["close"]
                        tmp_df.iloc[0] = [
                            first_valid_price,
                            first_valid_price,
                            first_valid_price,
                            first_valid_price,
                            0.0,
                        ]
                        break
            # if the close price of the first row is still NaN
            # (all the prices are NaN in this case)
            if str(tmp_df.iloc[0]["close"]) == "nan":
                print(
                    "Missing data for ticker: ",
                    tic,
                    " . The prices are all NaN. Fill with 0.",
                )
                tmp_df.iloc[0] = [0.0, 0.0, 0.0, 0.0, 0.0]

            # forward filling row by row
            for i in range(tmp_df.shape[0]):
                if str(tmp_df.iloc[i]["close"]) == "nan":
                    previous_close = tmp_df.iloc[i - 1]["close"]
                    if str(previous_close) == "nan":
                        raise ValueError
                    tmp_df.iloc[i] = [
                        previous_close,
                        previous_close,
                        previous_close,
                        previous_close,
                        0.0,
                    ]
            tmp_df = tmp_df.astype(float)
            tmp_df["tic"] = tic
            new_df = pd.concat([new_df, tmp_df])

        new_df = new_df.reset_index()
        new_df = new_df.rename(columns={"index": "time"})

        print("Data clean finished!")

        self.dataframe = new_df

    def get_trading_days(self, start, end):
        nyse = tc.get_calendar("NYSE")
        df = nyse.sessions_in_range(
            pd.Timestamp(start, tz=pytz.UTC), pd.Timestamp(end, tz=pytz.UTC)
        )
        return [str(day)[:10] for day in df]

    def fetch_latest_data(
        self, ticker_list, time_interval, tech_indicator_list, limit=100
    ) -> pd.DataFrame:
        data_df = pd.DataFrame()
        for tic in ticker_list:
            barset = self.api.get_barset([tic], time_interval, limit=limit).df[tic]
            barset["tic"] = tic
            barset = barset.reset_index()
            data_df = pd.concat([data_df, barset], ignore_index=True)

        data_df = data_df.reset_index(drop=True)
        start_time = data_df.time.min()
        end_time = data_df.time.max()
        times = []
        current_time = start_time
        end = end_time + pd.Timedelta(minutes=1)
        while current_time != end:
            times.append(current_time)
            current_time += pd.Timedelta(minutes=1)

        df = data_df.copy()
        new_df = pd.DataFrame()
        for tic in ticker_list:
            tmp_df = pd.DataFrame(
                columns=["open", "high", "low", "close", "volume"], index=times
            )
            tic_df = df[df.tic == tic]
            for i in range(tic_df.shape[0]):
                tmp_df.loc[tic_df.iloc[i]["time"]] = tic_df.iloc[i][
                    ["open", "high", "low", "close", "volume"]
                ]

            if str(tmp_df.iloc[0]["close"]) == "nan":
                for i in range(tmp_df.shape[0]):
                    if str(tmp_df.iloc[i]["close"]) != "nan":
                        first_valid_close = tmp_df.iloc[i]["close"]
                        tmp_df.iloc[0] = [
                            first_valid_close,
                            first_valid_close,
                            first_valid_close,
                            first_valid_close,
                            0.0,
                        ]
                        break
            if str(tmp_df.iloc[0]["close"]) == "nan":
                print(
                    "Missing data for ticker: ",
                    tic,
                    " . The prices are all NaN. Fill with 0.",
                )
                tmp_df.iloc[0] = [0.0, 0.0, 0.0, 0.0, 0.0]

            for i in range(tmp_df.shape[0]):
                if str(tmp_df.iloc[i]["close"]) == "nan":
                    previous_close = tmp_df.iloc[i - 1]["close"]
                    if str(previous_close) == "nan":
                        raise ValueError
                    tmp_df.iloc[i] = [
                        previous_close,
                        previous_close,
                        previous_close,
                        previous_close,
                        0.0,
                    ]
            tmp_df = tmp_df.astype(float)
            tmp_df["tic"] = tic
            new_df = pd.concat([new_df, tmp_df])

        new_df = new_df.reset_index()
        new_df = new_df.rename(columns={"index": "time"})

        # run the shared indicator/array steps from _Base on the latest bars
        self.dataframe = new_df
        self.add_technical_indicator(tech_indicator_list)
        self.dataframe["VIXY"] = 0

        price_array, tech_array, turbulence_array = self.df_to_array(
            tech_indicator_list, if_vix=True
        )
        latest_price = price_array[-1]
        latest_tech = tech_array[-1]
        turb_df = self.api.get_barset(["VIXY"], time_interval, limit=1).df["VIXY"]
        latest_turb = turb_df["close"].values
        return latest_price, latest_tech, latest_turb

    def get_portfolio_history(self, start, end):
        trading_days = self.get_trading_days(start, end)
        df = pd.DataFrame()
        for day in trading_days:
            df = pd.concat(
                [
                    df,
                    self.api.get_portfolio_history(
                        date_start=day, timeframe="5Min"
                    ).df.iloc[:79],
                ]
            )
        equities = df.equity.values
        cumu_returns = equities / equities[0]
        cumu_returns = cumu_returns[~np.isnan(cumu_returns)]
        return cumu_returns
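
A usage sketch for the Alpaca processor above, assuming paper-trading credentials from alpaca.markets. The key strings are placeholders, and the "1Min" interval string is illustrative; newer alpaca_trade_api releases may expect a TimeFrame object for get_bars.

from finnlp.data_processors.alpaca import Alpaca

processor = Alpaca(
    data_source="alpaca",
    start_date="2021-03-01",
    end_date="2021-03-05",
    time_interval="1Min",
    API=None,                     # force construction from the key pair below
    API_KEY="PK...",              # placeholder paper-trading key
    API_SECRET="...",             # placeholder secret
    API_BASE_URL="https://paper-api.alpaca.markets",
)
processor.download_data(
    ["AAPL", "MSFT"], processor.start_date, processor.end_date, "1Min"
)
processor.clean_data()  # reindexes each ticker onto the full 390-minute NYSE session grid
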
finnlp/data_processors/alphavantage.py
ADDED
@@ -0,0 +1,92 @@
import datetime
import json
from typing import List

import pandas as pd
import requests

from finnlp.utils.config import BINANCE_BASE_URL
from finnlp.utils.config import TIME_ZONE_BERLIN
from finnlp.utils.config import TIME_ZONE_JAKARTA
from finnlp.utils.config import TIME_ZONE_PARIS
from finnlp.utils.config import TIME_ZONE_SELFDEFINED
from finnlp.utils.config import TIME_ZONE_SHANGHAI
from finnlp.utils.config import TIME_ZONE_USEASTERN
from finnlp.utils.config import USE_TIME_ZONE_SELFDEFINED
from finnlp.data_processors._base import _Base
from finnlp.data_processors._base import calc_time_zone


def transfer_date(d):
    # format a date as "YYYY-MM-DD" with zero-padded month and day
    date = str(d.year)
    date += "-"
    if len(str(d.month)) == 1:
        date += "0" + str(d.month)
    else:
        date += str(d.month)
    date += "-"
    if len(str(d.day)) == 1:
        date += "0" + str(d.day)
    else:
        date += str(d.day)
    return date


class Alphavantage(_Base):
    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)

        assert time_interval == "1d", "please set the time_interval 1d"

    # supported time interval: ["1d"]
    def download_data(
        self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
    ):
        # self.time_zone = calc_time_zone(
        #     ticker_list, TIME_ZONE_SELFDEFINED, USE_TIME_ZONE_SELFDEFINED
        # )
        self.dataframe = pd.DataFrame()
        for ticker in ticker_list:
            url = (
                "https://www.alphavantage.co/query?function=TIME_SERIES_DAILY&symbol="
                + ticker
                + "&apikey=demo"
            )
            r = requests.get(url)
            data = r.json()
            data2 = json.dumps(data["Time Series (Daily)"])
            df2 = pd.read_json(data2)

            # transpose: dates become the index, OHLCV fields become columns
            df3 = pd.DataFrame(df2.values.T, columns=df2.index, index=df2.columns)
            df3.rename(
                columns={
                    "1. open": "open",
                    "2. high": "high",
                    "3. low": "low",
                    "4. close": "close",
                    "5. volume": "volume",
                },
                inplace=True,
            )
            df3["tic"] = ticker
            dates = [transfer_date(df2.index[i]) for i in range(len(df2.index))]
            df3["date"] = dates
            self.dataframe = pd.concat([self.dataframe, df3])
        self.dataframe = self.dataframe.sort_values(by=["date", "tic"]).reset_index(
            drop=True
        )

        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )
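
A usage sketch for the Alphavantage processor above. download_data hard-codes apikey=demo in the query URL, and Alpha Vantage honors the demo key only for a few symbols such as IBM, so any other ticker would need a real key substituted into the URL.

from finnlp.data_processors.alphavantage import Alphavantage

processor = Alphavantage(
    data_source="alphavantage",
    start_date="2021-01-01",
    end_date="2021-06-30",
    time_interval="1d",   # the constructor asserts this is "1d"
)
processor.download_data(["IBM"])  # writes ./data/dataset.csv
print(processor.dataframe[["tic", "date", "close"]].head())
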
finnlp/data_processors/baostock.py
ADDED
@@ -0,0 +1,114 @@
from typing import List

import baostock as bs
import numpy as np
import pandas as pd
import pytz
import yfinance as yf

"""Reference: https://github.com/AI4Finance-LLC/FinRL"""

try:
    import exchange_calendars as tc
except:
    print(
        "Cannot import exchange_calendars.",
        "If you are using python>=3.7, please install it.",
    )
    import trading_calendars as tc

    print("Use trading_calendars instead for yahoofinance processor..")
# from basic_processor import _Base
from meta.data_processors._base import _Base
from meta.data_processors._base import calc_time_zone

from meta.config import (
    TIME_ZONE_SHANGHAI,
    TIME_ZONE_USEASTERN,
    TIME_ZONE_PARIS,
    TIME_ZONE_BERLIN,
    TIME_ZONE_JAKARTA,
    TIME_ZONE_SELFDEFINED,
    USE_TIME_ZONE_SELFDEFINED,
    BINANCE_BASE_URL,
)


class Baostock(_Base):
    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)

    # Daily, weekly, and monthly K-lines, plus 5-, 15-, 30-, and 60-minute K-line data
    # ["5m", "15m", "30m", "60m", "1d", "1w", "1M"]
    def download_data(
        self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
    ):
        lg = bs.login()
        print("baostock login respond error_code:" + lg.error_code)
        print("baostock login respond error_msg:" + lg.error_msg)

        self.time_zone = calc_time_zone(
            ticker_list, TIME_ZONE_SELFDEFINED, USE_TIME_ZONE_SELFDEFINED
        )
        self.dataframe = pd.DataFrame()
        for ticker in ticker_list:
            nonstandard_ticker = self.transfer_standard_ticker_to_nonstandard(ticker)
            # All supported: "date,code,open,high,low,close,preclose,volume,amount,adjustflag,turn,tradestatus,pctChg,isST"
            rs = bs.query_history_k_data_plus(
                nonstandard_ticker,
                "date,code,open,high,low,close,volume",
                start_date=self.start_date,
                end_date=self.end_date,
                frequency=self.time_interval,
                adjustflag="3",
            )

            print("baostock download_data respond error_code:" + rs.error_code)
            print("baostock download_data respond error_msg:" + rs.error_msg)

            data_list = []
            while (rs.error_code == "0") & rs.next():
                data_list.append(rs.get_row_data())
            df = pd.DataFrame(data_list, columns=rs.fields)
            df.loc[:, "code"] = [ticker] * df.shape[0]
            self.dataframe = pd.concat([self.dataframe, df])
        self.dataframe = self.dataframe.sort_values(by=["date", "code"]).reset_index(
            drop=True
        )
        bs.logout()

        self.dataframe.open = self.dataframe.open.astype(float)
        self.dataframe.high = self.dataframe.high.astype(float)
        self.dataframe.low = self.dataframe.low.astype(float)
        self.dataframe.close = self.dataframe.close.astype(float)
        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

    def get_trading_days(self, start, end):
        lg = bs.login()
        print("baostock login respond error_code:" + lg.error_code)
        print("baostock login respond error_msg:" + lg.error_msg)
        result = bs.query_trade_dates(start_date=start, end_date=end)
        bs.logout()
        return result

    # "600000.XSHG" -> "sh.600000"
    # "000612.XSHE" -> "sz.000612"
    def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str:
        n, alpha = ticker.split(".")
        assert alpha in ["XSHG", "XSHE"], "Wrong alpha"
        if alpha == "XSHG":
            nonstandard_ticker = "sh." + n
        elif alpha == "XSHE":
            nonstandard_ticker = "sz." + n
        return nonstandard_ticker
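A minimal, hypothetical usage sketch for the Baostock processor (standard-suffix tickers are converted by transfer_standard_ticker_to_nonstandard; time_interval is passed straight to bs.query_history_k_data_plus, whose frequency codes, as an assumption about the baostock API, are "d"/"w"/"m" and "5"/"15"/"30"/"60"):

from meta.data_processors.baostock import Baostock  # adjust to the actual package layout

processor = Baostock(
    data_source="baostock",
    start_date="2021-01-04",
    end_date="2021-06-30",
    time_interval="d",  # daily bars; see the frequency note above
)
processor.download_data(["600000.XSHG", "000612.XSHE"], save_path="./data/baostock.csv")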
finnlp/data_processors/binance.py
ADDED
@@ -0,0 +1,434 @@
import datetime as dt
import json
import os
import urllib.error
import urllib.request
import zipfile
from datetime import *
from pathlib import Path
from typing import List

import pandas as pd
import requests

from meta.config import BINANCE_BASE_URL
from meta.config import TIME_ZONE_BERLIN
from meta.config import TIME_ZONE_JAKARTA
from meta.config import TIME_ZONE_PARIS
from meta.config import TIME_ZONE_SELFDEFINED
from meta.config import TIME_ZONE_SHANGHAI
from meta.config import TIME_ZONE_USEASTERN
from meta.config import USE_TIME_ZONE_SELFDEFINED
from meta.data_processors._base import _Base
from meta.data_processors._base import check_date

# from _base import check_date


class Binance(_Base):
    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        if time_interval == "1D":
            raise ValueError("Please use the time_interval 1d instead of 1D")
        if time_interval == "1d":
            check_date(start_date)
            check_date(end_date)
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
        self.url = "https://api.binance.com/api/v3/klines"
        self.time_diff = None

    # main functions
    def download_data(
        self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
    ):
        startTime = dt.datetime.strptime(self.start_date, "%Y-%m-%d")
        endTime = dt.datetime.strptime(self.end_date, "%Y-%m-%d")

        self.start_time = self.stringify_dates(startTime)
        self.end_time = self.stringify_dates(endTime)
        self.interval = self.time_interval
        self.limit = 1440

        # 1s for now, will add support for variable time and variable tick soon
        if self.time_interval == "1s":
            # as per https://binance-docs.github.io/apidocs/spot/en/#compressed-aggregate-trades-list
            self.limit = 1000
            final_df = self.fetch_n_combine(self.start_date, self.end_date, ticker_list)
        else:
            final_df = pd.DataFrame()
            for i in ticker_list:
                hist_data = self.dataframe_with_limit(symbol=i)
                df = hist_data.iloc[:-1].dropna()
                df["tic"] = i
                final_df = pd.concat([final_df, df], axis=0, join="outer")
        self.dataframe = final_df

        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

    # def clean_data(self, df):
    #     df = df.dropna()
    #     return df

    # def add_technical_indicator(self, df, tech_indicator_list):
    #     print('Adding self-defined technical indicators is NOT supported yet.')
    #     print('Use default: MACD, RSI, CCI, DX.')
    #     self.tech_indicator_list = ['open', 'high', 'low', 'close', 'volume',
    #                                 'macd', 'macd_signal', 'macd_hist',
    #                                 'rsi', 'cci', 'dx']
    #     final_df = pd.DataFrame()
    #     for i in df.tic.unique():
    #         tic_df = df[df.tic==i]
    #         tic_df['macd'], tic_df['macd_signal'], tic_df['macd_hist'] = MACD(tic_df['close'], fastperiod=12,
    #                                                                           slowperiod=26, signalperiod=9)
    #         tic_df['rsi'] = RSI(tic_df['close'], timeperiod=14)
    #         tic_df['cci'] = CCI(tic_df['high'], tic_df['low'], tic_df['close'], timeperiod=14)
    #         tic_df['dx'] = DX(tic_df['high'], tic_df['low'], tic_df['close'], timeperiod=14)
    #         final_df = final_df.append(tic_df)
    #
    #     return final_df

    # def add_turbulence(self, df):
    #     print('Turbulence not supported yet. Return original DataFrame.')
    #
    #     return df

    # def add_vix(self, df):
    #     print('VIX is not applicable for cryptocurrencies. Return original DataFrame')
    #
    #     return df

    # def df_to_array(self, df, tech_indicator_list, if_vix):
    #     unique_ticker = df.tic.unique()
    #     price_array = np.column_stack([df[df.tic==tic].close for tic in unique_ticker])
    #     tech_array = np.hstack([df.loc[(df.tic==tic), tech_indicator_list] for tic in unique_ticker])
    #     assert price_array.shape[0] == tech_array.shape[0]
    #     return price_array, tech_array, np.array([])

    # helper functions
    def stringify_dates(self, date: dt.datetime):
        return str(int(date.timestamp() * 1000))

    def get_binance_bars(self, last_datetime, symbol):
        """
        klines api returns data in the following order:
        open_time, open_price, high_price, low_price, close_price,
        volume, close_time, quote_asset_volume, n_trades,
        taker_buy_base_asset_volume, taker_buy_quote_asset_volume,
        ignore
        """
        req_params = {
            "symbol": symbol,
            "interval": self.interval,
            "startTime": last_datetime,
            "endTime": self.end_time,
            "limit": self.limit,
        }
        # For debugging purposes, uncomment these lines and if they throw an error
        # then you may have an error in req_params
        # r = requests.get(self.url, params=req_params)
        # print(r.text)
        df = pd.DataFrame(requests.get(self.url, params=req_params).json())

        if df.empty:
            return None

        df = df.iloc[:, 0:6]
        df.columns = ["datetime", "open", "high", "low", "close", "volume"]

        df[["open", "high", "low", "close", "volume"]] = df[
            ["open", "high", "low", "close", "volume"]
        ].astype(float)

        # No stock split and dividend announcement, hence adjusted close is the same as close
        df["adjusted_close"] = df["close"]
        df["datetime"] = df.datetime.apply(
            lambda x: dt.datetime.fromtimestamp(x / 1000.0)
        )
        df.reset_index(drop=True, inplace=True)

        return df

    def get_newest_bars(self, symbols, interval, limit):
        merged_df = pd.DataFrame()
        for symbol in symbols:
            req_params = {
                "symbol": symbol,
                "interval": interval,
                "limit": limit,
            }

            df = pd.DataFrame(
                requests.get(self.url, params=req_params).json(),
                index=range(limit),
            )

            if df.empty:
                return None

            df = df.iloc[:, 0:6]
            df.columns = ["datetime", "open", "high", "low", "close", "volume"]

            df[["open", "high", "low", "close", "volume"]] = df[
                ["open", "high", "low", "close", "volume"]
            ].astype(float)

            # No stock split and dividend announcement, hence adjusted close is the same as close
            df["adjusted_close"] = df["close"]
            df["datetime"] = df.datetime.apply(
                lambda x: dt.datetime.fromtimestamp(x / 1000.0)
            )
            df["tic"] = symbol
            df = df.rename(columns={"datetime": "time"})
            df.reset_index(drop=True, inplace=True)
            merged_df = merged_df.append(df)

        return merged_df

    def dataframe_with_limit(self, symbol):
        final_df = pd.DataFrame()
        last_datetime = self.start_time
        while True:
            new_df = self.get_binance_bars(last_datetime, symbol)
            if new_df is None:
                break

            if last_datetime == self.end_time:
                break

            final_df = pd.concat([final_df, new_df], axis=0, join="outer")
            # last_datetime = max(new_df.datetime) + dt.timedelta(days=1)
            last_datetime = max(new_df.datetime)
            if isinstance(last_datetime, pd.Timestamp):
                last_datetime = last_datetime.to_pydatetime()

            if self.time_diff is None:
                self.time_diff = new_df.loc[1]["datetime"] - new_df.loc[0]["datetime"]

            last_datetime = last_datetime + self.time_diff
            last_datetime = self.stringify_dates(last_datetime)

        date_value = final_df["datetime"].apply(
            lambda x: x.strftime("%Y-%m-%d %H:%M:%S")
        )
        final_df.insert(0, "time", date_value)
        final_df.drop("datetime", inplace=True, axis=1)
        return final_df

    def get_download_url(self, file_url):
        return f"{BINANCE_BASE_URL}{file_url}"

    # downloads zip, unzips zip and deletes zip
    def download_n_unzip_file(self, base_path, file_name, date_range=None):
        download_path = f"{base_path}{file_name}"
        if date_range:
            date_range = date_range.replace(" ", "_")
            base_path = os.path.join(base_path, date_range)

        # raw_cache_dir = get_destination_dir("./cache/tick_raw")
        raw_cache_dir = "./cache/tick_raw"
        zip_save_path = os.path.join(raw_cache_dir, file_name)

        csv_name = os.path.splitext(file_name)[0] + ".csv"
        csv_save_path = os.path.join(raw_cache_dir, csv_name)

        fhandles = []

        if os.path.exists(csv_save_path):
            print(f"\nfile already exists! {csv_save_path}")
            return [csv_save_path]

        # make the "cache" directory (only)
        if not os.path.exists(raw_cache_dir):
            Path(raw_cache_dir).mkdir(parents=True, exist_ok=True)

        try:
            download_url = self.get_download_url(download_path)
            dl_file = urllib.request.urlopen(download_url)
            length = dl_file.getheader("content-length")
            blocksize = 4096  # default when the server sends no content-length header
            if length:
                length = int(length)
                blocksize = max(4096, length // 100)

            with open(zip_save_path, "wb") as out_file:
                dl_progress = 0
                print(f"\nFile Download: {zip_save_path}")
                while True:
                    buf = dl_file.read(blocksize)
                    if not buf:
                        break
                    out_file.write(buf)
                    # visuals
                    # dl_progress += len(buf)
                    # done = int(50 * dl_progress / length)
                    # sys.stdout.write("\r[%s%s]" % ('#' * done, '.' * (50-done)) )
                    # sys.stdout.flush()

            # unzip and delete zip
            with zipfile.ZipFile(zip_save_path) as zip:
                # guaranteed just 1 csv
                csvpath = zip.extract(zip.namelist()[0], raw_cache_dir)
            fhandles.append(csvpath)
            os.remove(zip_save_path)
            return fhandles

        except urllib.error.HTTPError:
            print(f"\nFile not found: {download_url}")

    def convert_to_date_object(self, d):
        year, month, day = [int(x) for x in d.split("-")]
        return date(year, month, day)

    def get_path(
        self,
        trading_type,
        market_data_type,
        time_period,
        symbol,
        interval=None,
    ):
        trading_type_path = "data/spot"
        # currently just supporting spot
        if trading_type != "spot":
            trading_type_path = f"data/futures/{trading_type}"
        return (
            f"{trading_type_path}/{time_period}/{market_data_type}/{symbol.upper()}/{interval}/"
            if interval is not None
            else f"{trading_type_path}/{time_period}/{market_data_type}/{symbol.upper()}/"
        )

    # helpers for manipulating tick level data (1s intervals)
    def download_daily_aggTrades(
        self, symbols, num_symbols, dates, start_date, end_date
    ):
        trading_type = "spot"
        date_range = start_date + " " + end_date
        start_date = self.convert_to_date_object(start_date)
        end_date = self.convert_to_date_object(end_date)

        print(f"Found {num_symbols} symbols")

        map = {}
        for current, symbol in enumerate(symbols):
            map[symbol] = []
            print(
                f"[{current + 1}/{num_symbols}] - start download daily {symbol} aggTrades "
            )
            for date in dates:
                current_date = self.convert_to_date_object(date)
                if current_date >= start_date and current_date <= end_date:
                    path = self.get_path(trading_type, "aggTrades", "daily", symbol)
                    file_name = f"{symbol.upper()}-aggTrades-{date}.zip"
                    fhandle = self.download_n_unzip_file(path, file_name, date_range)
                    map[symbol] += fhandle
        return map

    def fetch_aggTrades(self, startDate: str, endDate: str, tickers: List[str]):
        # all valid symbols traded on v3 api
        response = urllib.request.urlopen(
            "https://api.binance.com/api/v3/exchangeInfo"
        ).read()
        valid_symbols = list(
            map(
                lambda symbol: symbol["symbol"],
                json.loads(response)["symbols"],
            )
        )

        for tic in tickers:
            if tic not in valid_symbols:
                print(tic + " not a valid ticker, removing from download")
        tickers = list(set(tickers) & set(valid_symbols))
        num_symbols = len(tickers)
        # not adding tz yet
        # for ffill missing data on starting on first day 00:00:00 (if any)
        tminus1 = (self.convert_to_date_object(startDate) - dt.timedelta(1)).strftime(
            "%Y-%m-%d"
        )
        dates = pd.date_range(start=tminus1, end=endDate)
        dates = [date.strftime("%Y-%m-%d") for date in dates]
        return self.download_daily_aggTrades(
            tickers, num_symbols, dates, tminus1, endDate
        )

    # Dict[str]:List[str] -> pd.DataFrame
    def combine_raw(self, map):
        # same format as jingyang's current data format
        final_df = pd.DataFrame()
        # using AggTrades with headers from https://github.com/binance/binance-public-data/
        colNames = [
            "AggregatetradeId",
            "Price",
            "volume",
            "FirsttradeId",
            "LasttradeId",
            "time",
            "buyerWasMaker",
            "tradeWasBestPriceMatch",
        ]
        for tic in map.keys():
            security = pd.DataFrame()
            for i, csv in enumerate(map[tic]):
                dailyticks = pd.read_csv(
                    csv,
                    names=colNames,
                    index_col=["time"],
                    parse_dates=["time"],
                    date_parser=lambda epoch: pd.to_datetime(epoch, unit="ms"),
                )
                dailyfinal = dailyticks.resample("1s").agg(
                    {"Price": "ohlc", "volume": "sum"}
                )
                dailyfinal.columns = dailyfinal.columns.droplevel(0)
                # favor continuous series
                # dailyfinal.dropna(inplace=True)

                # implemented T-1 day ffill day start missing values
                # guaranteed first csv is tminus1 day
                if i == 0:
                    tmr = dailyfinal.index[0].date() + dt.timedelta(1)
                    tmr_dt = dt.datetime.combine(tmr, dt.time.min)
                    last_time_stamp_dt = dailyfinal.index[-1].to_pydatetime()
                    s_delta = (tmr_dt - last_time_stamp_dt).seconds
                    lastsample = dailyfinal.iloc[-1:]
                    lastsample.index = lastsample.index.shift(s_delta, "s")
                else:
                    day_dt = dailyfinal.index[0].date()
                    day_str = day_dt.strftime("%Y-%m-%d")
                    nextday_str = (day_dt + dt.timedelta(1)).strftime("%Y-%m-%d")
                    if dailyfinal.index[0].second != 0:
                        # append last sample
                        dailyfinal = lastsample.append(dailyfinal)
                    # otherwise, just reindex and ffill
                    dailyfinal = dailyfinal.reindex(
                        pd.date_range(day_str, nextday_str, freq="1s")[:-1],
                        method="ffill",
                    )
                    # save reference info (guaranteed to be :59)
                    lastsample = dailyfinal.iloc[-1:]
                    lastsample.index = lastsample.index.shift(1, "s")

                    if dailyfinal.shape[0] != 86400:
                        raise ValueError("everyday should have 86400 datapoints")

                    # only save real startDate - endDate
                    security = security.append(dailyfinal)

            security.ffill(inplace=True)
            security["tic"] = tic
            final_df = final_df.append(security)
        return final_df

    def fetch_n_combine(self, startDate, endDate, tickers):
        # return combine_raw(fetchAggTrades(startDate, endDate, tickers))
        mapping = self.fetch_aggTrades(startDate, endDate, tickers)
        return self.combine_raw(mapping)
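A minimal, hypothetical usage sketch for the Binance processor (needs network access to api.binance.com; dates must be %Y-%m-%d, matching the strptime calls in download_data):

from meta.data_processors.binance import Binance  # adjust to the actual package layout

processor = Binance(
    data_source="binance",
    start_date="2021-01-01",
    end_date="2021-01-31",
    time_interval="1d",  # "1D" is rejected in __init__; "1s" switches to the aggTrades path
)
processor.download_data(["BTCUSDT", "ETHUSDT"], save_path="./data/binance.csv")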
finnlp/data_processors/ccxt.py
ADDED
@@ -0,0 +1,143 @@
import calendar
from datetime import datetime
from typing import List

import ccxt
import numpy as np
import pandas as pd

from meta.data_processors._base import _Base

# from basic_processor import _Base


class Ccxt(_Base):
    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
        self.binance = ccxt.binance()

    def download_data(
        self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
    ):
        crypto_column = pd.MultiIndex.from_product(
            [ticker_list, ["open", "high", "low", "close", "volume"]]
        )
        first_time = True
        for ticker in ticker_list:
            start_dt = datetime.strptime(self.start_date, "%Y%m%d %H:%M:%S")
            end_dt = datetime.strptime(self.end_date, "%Y%m%d %H:%M:%S")
            start_timestamp = calendar.timegm(start_dt.utctimetuple())
            end_timestamp = calendar.timegm(end_dt.utctimetuple())
            if self.time_interval == "1Min":
                date_list = [
                    datetime.utcfromtimestamp(float(time))
                    for time in range(start_timestamp, end_timestamp, 60 * 720)
                ]
            else:
                date_list = [
                    datetime.utcfromtimestamp(float(time))
                    for time in range(start_timestamp, end_timestamp, 60 * 1440)
                ]
            df = self.ohlcv(date_list, ticker, self.time_interval)
            if first_time:
                dataset = pd.DataFrame(columns=crypto_column, index=df["time"].values)
                first_time = False
            temp_col = pd.MultiIndex.from_product(
                [[ticker], ["open", "high", "low", "close", "volume"]]
            )
            dataset[temp_col] = df[["open", "high", "low", "close", "volume"]].values
            print("Actual end time: " + str(df["time"].values[-1]))
        self.dataframe = dataset

        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

    # def add_technical_indicators(self, df, pair_list, tech_indicator_list = [
    #                             'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'dx_30',
    #                             'close_30_sma', 'close_60_sma']):
    #     df = df.dropna()
    #     df = df.copy()
    #     column_list = [pair_list, ['open','high','low','close','volume']+(tech_indicator_list)]
    #     column = pd.MultiIndex.from_product(column_list)
    #     index_list = df.index
    #     dataset = pd.DataFrame(columns=column,index=index_list)
    #     for pair in pair_list:
    #         pair_column = pd.MultiIndex.from_product([[pair],['open','high','low','close','volume']])
    #         dataset[pair_column] = df[pair]
    #         temp_df = df[pair].reset_index().sort_values(by=['index'])
    #         temp_df = temp_df.rename(columns={'index':'date'})
    #         crypto_df = Sdf.retype(temp_df.copy())
    #         for indicator in tech_indicator_list:
    #             temp_indicator = crypto_df[indicator].values.tolist()
    #             dataset[(pair,indicator)] = temp_indicator
    #     print('Successfully added technical indicators')
    #     return dataset

    def df_to_ary(self, pair_list, tech_indicator_list=None):
        if tech_indicator_list is None:
            tech_indicator_list = [
                "macd",
                "boll_ub",
                "boll_lb",
                "rsi_30",
                "dx_30",
                "close_30_sma",
                "close_60_sma",
            ]
        df = self.dataframe
        df = df.dropna()
        date_ary = df.index.values
        price_array = df[pd.MultiIndex.from_product([pair_list, ["close"]])].values
        tech_array = df[
            pd.MultiIndex.from_product([pair_list, tech_indicator_list])
        ].values
        return price_array, tech_array, date_ary

    def min_ohlcv(self, dt, pair, limit):
        since = calendar.timegm(dt.utctimetuple()) * 1000
        return self.binance.fetch_ohlcv(
            symbol=pair, timeframe="1m", since=since, limit=limit
        )

    def ohlcv(self, dt, pair, period="1d"):
        ohlcv = []
        limit = 1000
        if period == "1Min":
            limit = 720
        elif period == "1D":
            limit = 1
        elif period == "1H":
            limit = 24
        elif period == "5Min":
            limit = 288
        for i in dt:
            start_dt = i
            since = calendar.timegm(start_dt.utctimetuple()) * 1000
            if period == "1Min":
                ohlcv.extend(self.min_ohlcv(start_dt, pair, limit))
            else:
                ohlcv.extend(
                    self.binance.fetch_ohlcv(
                        symbol=pair, timeframe=period, since=since, limit=limit
                    )
                )
        df = pd.DataFrame(
            ohlcv, columns=["time", "open", "high", "low", "close", "volume"]
        )
        df["time"] = [datetime.fromtimestamp(float(time) / 1000) for time in df["time"]]
        df["open"] = df["open"].astype(np.float64)
        df["high"] = df["high"].astype(np.float64)
        df["low"] = df["low"].astype(np.float64)
        df["close"] = df["close"].astype(np.float64)
        df["volume"] = df["volume"].astype(np.float64)
        return df
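A minimal, hypothetical usage sketch for the Ccxt processor; note that download_data parses dates with the "%Y%m%d %H:%M:%S" format and that tickers use ccxt's pair notation:

from meta.data_processors.ccxt import Ccxt  # adjust to the actual package layout

processor = Ccxt(
    data_source="ccxt",
    start_date="20210101 00:00:00",
    end_date="20210131 00:00:00",
    time_interval="1d",
)
processor.download_data(["BTC/USDT"], save_path="./data/ccxt.csv")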
finnlp/data_processors/fx.py
ADDED
@@ -0,0 +1,102 @@
import os
import sys

import pandas as pd
from finta import TA


def add_time_feature(df, symbol, dt_col_name="time"):
    """Read a csv into df and index on time.
    dt_col_name can be any unit from minutes to day; time becomes the index of the DataFrame.
    Must have columns [(time_col), (asset_col), Open, Close, High, Low, day].
    data_process will add additional time information: time (index), minute, hour, weekday, week, month, year, day (since 1970).
    Use StopLoss and ProfitTaken to simplify the action:
    feed a fixed StopLoss (SL = 200) and PT = SL * ratio.
    Action space: [action[0,2], ratio[0,10]];
    rewards is points.

    Adds hourly, dayofweek (0-6, Sun-Sat).
    Args:
        file (str): file path/name.csv
    """

    df["symbol"] = symbol
    df["dt"] = pd.to_datetime(df[dt_col_name])
    df.index = df["dt"]
    df["minute"] = df["dt"].dt.minute
    df["hour"] = df["dt"].dt.hour
    df["weekday"] = df["dt"].dt.dayofweek
    df["week"] = df["dt"].dt.isocalendar().week
    df["month"] = df["dt"].dt.month
    df["year"] = df["dt"].dt.year
    df["day"] = df["dt"].dt.day
    # df = df.set_index('dt')
    return df


# 'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'dx_30','close_30_sma', 'close_60_sma'
def tech_indictors(df):
    df["macd"] = TA.MACD(df).SIGNAL
    df["boll_ub"] = TA.BBANDS(df).BB_UPPER
    df["boll_lb"] = TA.BBANDS(df).BB_LOWER
    df["rsi_30"] = TA.RSI(df, period=30)
    df["dx_30"] = TA.ADX(df, period=30)
    df["close_30_sma"] = TA.SMA(df, period=30)
    df["close_60_sma"] = TA.SMA(df, period=60)

    # fill NaN to 0
    df = df.fillna(0)
    print(
        f"--------df head - tail ----------------\n{df.head(3)}\n{df.tail(3)}\n---------------------------------"
    )

    return df


def split_timeserious(df, key_ts="dt", freq="W", symbol=""):
    """Import df, split it into hourly, daily, weekly, or monthly chunks,
    and save each chunk into a subfolder.

    Args:
        df (pandas df with timestamp as part of a multi index):
        spliter (str): H, D, W, M, Y
    """

    freq_name = {
        "H": "hourly",
        "D": "daily",
        "W": "weekly",
        "M": "monthly",
        "Y": "yearly",
    }
    for count, (n, g) in enumerate(df.groupby(pd.Grouper(level=key_ts, freq=freq))):
        p = f"./data/split/{symbol}/{freq_name[freq]}"
        os.makedirs(p, exist_ok=True)
        # fname = f'{symbol}_{n:%Y%m%d}_{freq}_{count}.csv'
        fname = f"{symbol}_{n:%Y}_{count}.csv"
        fn = f"{p}/{fname}"
        print(f"save to:{fn}")
        g.reset_index(drop=True, inplace=True)
        g.drop(columns=["dt"], inplace=True)
        g.to_csv(fn)
    return


"""
python ./neo_finrl/data_processors/fx.py GBPUSD W ./data/raw/GBPUSD_raw.csv
symbol="GBPUSD"
freq = [H, D, W, M]
file .csv, column names [time, Open, High, Low, Close, Vol]
"""
if __name__ == "__main__":
    symbol, freq, file = sys.argv[1], sys.argv[2], sys.argv[3]
    print(f"processing... symbol:{symbol} freq:{freq} file:{file}")
    try:
        df = pd.read_csv(file)
    except Exception:
        print(f"No such file or directory: {file}")
        exit(0)
    df = add_time_feature(df, symbol=symbol, dt_col_name="time")
    df = tech_indictors(df)
    split_timeserious(df, freq=freq, symbol=symbol)
    print(f"Done!")
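A small, hypothetical illustration of add_time_feature on a synthetic frame (tech_indictors additionally needs the finta-compatible Open/High/Low/Close columns, which this frame already carries):

import pandas as pd

from finnlp.data_processors.fx import add_time_feature  # adjust to the actual package layout

df = pd.DataFrame(
    {
        "time": pd.date_range("2021-01-04", periods=3, freq="H"),
        "Open": [1.36, 1.37, 1.38],
        "High": [1.37, 1.38, 1.39],
        "Low": [1.35, 1.36, 1.37],
        "Close": [1.36, 1.37, 1.38],
    }
)
df = add_time_feature(df, symbol="GBPUSD", dt_col_name="time")
print(df[["symbol", "hour", "weekday", "week"]])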
finnlp/data_processors/iexcloud.py
ADDED
@@ -0,0 +1,143 @@
import os
from datetime import datetime
from typing import List

import pandas as pd
import pandas_market_calendars as mcal
import pytz
import requests

from meta.data_processors._base import _Base

# from _base import _Base


class Iexcloud(_Base):
    @classmethod
    def _get_base_url(cls, mode: str) -> str:
        as1 = "mode must be sandbox or production."
        assert mode in {"sandbox", "production"}, as1

        if mode == "sandbox":
            return "https://sandbox.iexapis.com"

        return "https://cloud.iexapis.com"

    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
        self.base_url = self._get_base_url(mode=kwargs["mode"])
        self.token = kwargs["token"] or os.environ.get("IEX_TOKEN")

    def download_data(
        self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
    ):
        """Returns end of day historical data for up to 15 years.

        Args:
            ticker_list (List[str]): List of the tickers to retrieve information.
            start_date (str): Oldest date of the range.
            end_date (str): Latest date of the range.

        Returns:
            pd.DataFrame: A pandas dataframe with end of day historical data
            for the specified tickers with the following columns:
            date, tic, open, high, low, close, adjusted_close, volume.

        Examples:
            kwargs['mode'] = 'sandbox'
            kwargs['token'] = 'Tsk_d633e2ff10d463...'
            >>> iex_dloader = Iexcloud(data_source='iexcloud', **kwargs)
            >>> iex_dloader.download_data(ticker_list=["AAPL", "NVDA"],
                                          start_date='2014-01-01',
                                          end_date='2021-12-12',
                                          time_interval='1D')
        """
        assert self.time_interval == "1D"  # one day

        price_data = pd.DataFrame()

        query_params = {
            "token": self.token,
        }

        if self.start_date and self.end_date:
            query_params["from"] = self.start_date
            query_params["to"] = self.end_date

        for stock in ticker_list:
            end_point = f"{self.base_url}/stable/time-series/HISTORICAL_PRICES/{stock}"

            response = requests.get(
                url=end_point,
                params=query_params,
            )
            if response.status_code != 200:
                raise requests.exceptions.RequestException(response.text)

            temp = pd.DataFrame.from_dict(data=response.json())
            temp["ticker"] = stock
            price_data = price_data.append(temp)
        price_data = price_data[
            [
                "date",
                "ticker",
                "open",
                "high",
                "low",
                "close",
                "fclose",
                "volume",
            ]
        ]
        price_data = price_data.rename(
            columns={
                "ticker": "tic",
                "date": "time",
                "fclose": "adjusted_close",
            }
        )

        # the "date" column was renamed to "time" above, so convert that column
        price_data["time"] = price_data["time"].map(
            lambda x: datetime.fromtimestamp(x / 1000, pytz.UTC).strftime("%Y-%m-%d")
        )

        self.dataframe = price_data

        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

    def get_trading_days(self, start: str, end: str) -> List[str]:
        """Retrieves every trading day between two dates.

        Args:
            start (str): Oldest date of the range.
            end (str): Latest date of the range.

        Returns:
            List[str]: List of all trading days in YYYY-mm-dd format.

        Examples:
            >>> iex_dloader = Iexcloud(data_source='iexcloud',
                                       mode='sandbox',
                                       token='Tsk_d633e2ff10d463...')
            >>> iex_dloader.get_trading_days(start='2014-01-01',
                                             end='2021-12-12')
            ['2021-12-15', '2021-12-16', '2021-12-17']
        """
        nyse = mcal.get_calendar("NYSE")

        df = nyse.schedule(
            start_date=pd.Timestamp(start, tz=pytz.UTC),
            end_date=pd.Timestamp(end, tz=pytz.UTC),
        )
        return df.applymap(lambda x: x.strftime("%Y-%m-%d")).market_open.to_list()
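A hypothetical construction sketch; note that __init__ reads kwargs["token"] directly, so pass token=None explicitly to fall back to the IEX_TOKEN environment variable:

import os

os.environ.setdefault("IEX_TOKEN", "Tsk_your_sandbox_token")  # placeholder token
iex_dloader = Iexcloud(
    data_source="iexcloud",
    start_date="2021-01-01",
    end_date="2021-06-30",
    time_interval="1D",  # the only interval download_data accepts
    mode="sandbox",
    token=None,  # falls back to os.environ["IEX_TOKEN"]
)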
finnlp/data_processors/joinquant.py
ADDED
@@ -0,0 +1,68 @@
import copy
import datetime
import os
from typing import List

import jqdatasdk as jq
import numpy as np

from meta.data_processors._base import _Base


class Joinquant(_Base):
    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
        if "username" in kwargs.keys() and "password" in kwargs.keys():
            jq.auth(kwargs["username"], kwargs["password"])

    def download_data(
        self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
    ):
        # joinquant supports: '1m', '5m', '15m', '30m', '60m', '120m', '1d', '1w', '1M'.
        # '1w' denotes one week; '1M' denotes one month.
        count = len(self.get_trading_days(self.start_date, self.end_date))
        df = jq.get_bars(
            security=ticker_list,
            count=count,
            unit=self.time_interval,
            fields=["date", "open", "high", "low", "close", "volume"],
            end_dt=self.end_date,
        )
        df = df.reset_index().rename(columns={"level_0": "tic"})
        self.dataframe = df

        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

    @staticmethod
    def preprocess(df, stock_list):
        n = len(stock_list)
        N = df.shape[0]
        assert N % n == 0
        d = int(N / n)
        stock1_ary = df.iloc[0:d, 1:].values
        temp_ary = stock1_ary
        for j in range(1, n):
            stocki_ary = df.iloc[j * d : (j + 1) * d, 1:].values
            temp_ary = np.hstack((temp_ary, stocki_ary))
        return temp_ary

    # start_day: str
    # end_day: str
    # output: list of str_of_trade_day, e.g., ['2021-09-01', '2021-09-02']
    def get_trading_days(self, start_day: str, end_day: str) -> List[str]:
        dates = jq.get_trade_days(start_day, end_day)
        str_dates = []
        for d in dates:
            tmp = datetime.date.strftime(d, "%Y-%m-%d")
            str_dates.append(tmp)
        # str_dates = [date2str(dt) for dt in dates]
        return str_dates
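A minimal, hypothetical usage sketch for the Joinquant processor (requires valid JQData credentials; tickers use jqdatasdk's XSHG/XSHE suffixes):

from meta.data_processors.joinquant import Joinquant  # adjust to the actual package layout

processor = Joinquant(
    data_source="joinquant",
    start_date="2021-01-04",
    end_date="2021-06-30",
    time_interval="1d",
    username="your_jqdata_account",  # placeholder credentials
    password="your_jqdata_password",
)
processor.download_data(["000001.XSHE"], save_path="./data/joinquant.csv")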
finnlp/data_processors/quandl.py
ADDED
@@ -0,0 +1,85 @@
from typing import List

import numpy as np
import pandas as pd
import pytz
import quandl
import yfinance as yf

"""Reference: https://github.com/AI4Finance-LLC/FinRL"""

try:
    import exchange_calendars as tc
except:
    print(
        "Cannot import exchange_calendars.",
        "If you are using python>=3.7, please install it.",
    )
    import trading_calendars as tc

    print("Use trading_calendars instead for yahoofinance processor..")
# from basic_processor import _Base
from meta.data_processors._base import _Base
from meta.data_processors._base import calc_time_zone

from meta.config import (
    TIME_ZONE_SHANGHAI,
    TIME_ZONE_USEASTERN,
    TIME_ZONE_PARIS,
    TIME_ZONE_BERLIN,
    TIME_ZONE_JAKARTA,
    TIME_ZONE_SELFDEFINED,
    USE_TIME_ZONE_SELFDEFINED,
    BINANCE_BASE_URL,
)

TIME_ZONE_SELFDEFINED = TIME_ZONE_USEASTERN  # If neither of the above is your time zone, you should define it, and set USE_TIME_ZONE_SELFDEFINED 1.
USE_TIME_ZONE_SELFDEFINED = 1  # 0 (default) or 1 (use the self defined)


class Quandl(_Base):
    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)

    def download_data(
        self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
    ):
        self.time_zone = calc_time_zone(
            ticker_list, TIME_ZONE_SELFDEFINED, USE_TIME_ZONE_SELFDEFINED
        )

        # Download and save the data in a pandas DataFrame:
        # data_df = pd.DataFrame()
        # # set paginate to True because Quandl limits tables API to 10,000 rows per call
        # data = quandl.get_table('ZACKS/FC', paginate=True, ticker=ticker_list, per_end_date={'gte': '2021-09-01'}, qopts={'columns': ['ticker', 'per_end_date']})
        # data = quandl.get('ZACKS/FC', ticker=ticker_list, start_date="2020-12-31", end_date="2021-12-31")
        self.dataframe = quandl.get_table(
            "ZACKS/FC",
            ticker=ticker_list,
            qopts={"columns": ["ticker", "date", "adjusted_close"]},
            date={"gte": self.start_date, "lte": self.end_date},
            paginate=True,
        )
        self.dataframe.dropna(inplace=True)
        self.dataframe.reset_index(drop=True, inplace=True)
        print("Shape of DataFrame: ", self.dataframe.shape)
        # print("Display DataFrame: ", data_df.head())

        self.dataframe.sort_values(by=["date", "ticker"], inplace=True)
        self.dataframe.reset_index(drop=True, inplace=True)

        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

    # def get_trading_days(self, start, end):
    #
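A hypothetical usage sketch (assumption: a Quandl/Nasdaq Data Link API key is conventionally set via quandl.ApiConfig.api_key, and the ZACKS/FC table queried above requires a subscription):

import quandl

quandl.ApiConfig.api_key = "your_quandl_api_key"  # placeholder key
processor = Quandl(
    data_source="quandl",
    start_date="2021-01-01",
    end_date="2021-06-30",
    time_interval="1d",
)
processor.download_data(["AAPL", "MSFT"], save_path="./data/quandl.csv")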
finnlp/data_processors/quantconnect.py
ADDED
@@ -0,0 +1,70 @@
from typing import List

from meta.config import BINANCE_BASE_URL
from meta.config import TIME_ZONE_BERLIN
from meta.config import TIME_ZONE_JAKARTA
from meta.config import TIME_ZONE_PARIS
from meta.config import TIME_ZONE_SELFDEFINED
from meta.config import TIME_ZONE_SHANGHAI
from meta.config import TIME_ZONE_USEASTERN
from meta.config import USE_TIME_ZONE_SELFDEFINED
from meta.data_processors._base import _Base

# from basic_processor import _Base


## The code of this file is used on the QuantConnect website, not locally.
class Quantconnect(_Base):
    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)

    # def data_fetch(start_time, end_time, stock_list, resolution=Resolution.Daily) :
    #     #resolution: Daily, Hour, Minute, Second
    #     qb = QuantBook()
    #     for stock in stock_list:
    #         qb.AddEquity(stock)
    #     history = qb.History(qb.Securities.Keys, start_time, end_time, resolution)
    #     return history

    def download_data(
        self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
    ):
        # self.time_zone = calc_time_zone(ticker_list, TIME_ZONE_SELFDEFINED, USE_TIME_ZONE_SELFDEFINED)

        # start_date = pd.Timestamp(start_date, tz=self.time_zone)
        # end_date = pd.Timestamp(end_date, tz=self.time_zone) + pd.Timedelta(days=1)
        qb = QuantBook()  # QuantBook is predefined in the QuantConnect research environment
        for stock in ticker_list:
            qb.AddEquity(stock)
        history = qb.History(
            qb.Securities.Keys,
            self.start_date,
            self.end_date,
            self.time_interval,
        )
        self.dataframe = history

        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

    # def preprocess(df, stock_list):
    #     df = df[['open','high','low','close','volume']]
    #     if_first_time = True
    #     for stock in stock_list:
    #         if if_first_time:
    #             ary = df.loc[stock].values
    #             if_first_time = False
    #         else:
    #             temp = df.loc[stock].values
    #             ary = np.hstack((ary,temp))
    #     return ary
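A hypothetical sketch of how the class above would be driven; as the file's own comment says, it only runs on the QuantConnect research platform, where QuantBook and Resolution are predefined:

# Inside a QuantConnect research notebook only; QuantBook is not importable locally.
processor = Quantconnect(
    data_source="quantconnect",
    start_date="2021-01-01",
    end_date="2021-06-30",
    time_interval=Resolution.Daily,  # assumption: passed straight through to qb.History, so a Resolution value
)
processor.download_data(["AAPL", "MSFT"], save_path="./data/quantconnect.csv")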
finnlp/data_processors/ricequant.py
ADDED
@@ -0,0 +1,131 @@
from typing import List

import rqdatac as ricequant

from meta.data_processors._base import _Base


class Ricequant(_Base):
    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
        if kwargs["username"] is None or kwargs["password"] is None:
            ricequant.init()  # if the license is already set, you can init without username and password
        else:
            ricequant.init(
                kwargs["username"], kwargs["password"]
            )  # init with username and password

    def download_data(
        self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
    ):
        # download data by calling RiceQuant API
        dataframe = ricequant.get_price(
            ticker_list,
            frequency=self.time_interval,
            start_date=self.start_date,
            end_date=self.end_date,
        )
        self.dataframe = dataframe

        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

    # def clean_data(self, df) -> pd.DataFrame:
    #     ''' RiceQuant data is already cleaned, we only need to transform data format here.
    #     No need for filling NaN data'''
    #     df = df.copy()
    #     # raw df uses multi-index (tic,time), reset it to single index (time)
    #     df = df.reset_index(level=[0,1])
    #     # rename column order_book_id to tic
    #     df = df.rename(columns={'order_book_id':'tic', 'datetime':'time'})
    #     # reserve columns needed
    #     df = df[['tic','time','open','high','low','close','volume']]
    #     # check if there is NaN values
    #     assert not df.isnull().values.any()
    #     return df

    # def add_vix(self, data):
    #     print('VIX is NOT applicable to China A-shares')
    #     return data

    # def calculate_turbulence(self, data, time_period=252):
    #     # can add other market assets
    #     df = data.copy()
    #     df_price_pivot = df.pivot(index="date", columns="tic", values="close")
    #     # use returns to calculate turbulence
    #     df_price_pivot = df_price_pivot.pct_change()
    #
    #     unique_date = df.date.unique()
    #     # start after a fixed time period
    #     start = time_period
    #     turbulence_index = [0] * start
    #     # turbulence_index = [0]
    #     count = 0
    #     for i in range(start, len(unique_date)):
    #         current_price = df_price_pivot[df_price_pivot.index == unique_date[i]]
    #         # use one year rolling window to calculate covariance
    #         hist_price = df_price_pivot[
    #             (df_price_pivot.index < unique_date[i])
    #             & (df_price_pivot.index >= unique_date[i - time_period])
    #         ]
    #         # Drop tickers which has number missing values more than the "oldest" ticker
    #         filtered_hist_price = hist_price.iloc[hist_price.isna().sum().min():].dropna(axis=1)
    #
    #         cov_temp = filtered_hist_price.cov()
    #         current_temp = current_price[[x for x in filtered_hist_price]] - np.mean(filtered_hist_price, axis=0)
    #         temp = current_temp.values.dot(np.linalg.pinv(cov_temp)).dot(
    #             current_temp.values.T
    #         )
    #         if temp > 0:
    #             count += 1
    #             if count > 2:
    #                 turbulence_temp = temp[0][0]
    #             else:
    #                 # avoid large outlier because the calculation just begins
    #                 turbulence_temp = 0
    #         else:
    #             turbulence_temp = 0
    #         turbulence_index.append(turbulence_temp)
    #
    #     turbulence_index = pd.DataFrame(
    #         {"date": df_price_pivot.index, "turbulence": turbulence_index}
    #     )
    #     return turbulence_index
    #
    # def add_turbulence(self, data, time_period=252):
    #     """
    #     add turbulence index from a precalculated dataframe
    #     :param data: (df) pandas dataframe
    #     :return: (df) pandas dataframe
    #     """
    #     df = data.copy()
    #     turbulence_index = self.calculate_turbulence(df, time_period=time_period)
    #     df = df.merge(turbulence_index, on="date")
    #     df = df.sort_values(["date", "tic"]).reset_index(drop=True)
    #     return df

    # def df_to_array(self, df, tech_indicator_list, if_vix):
    #     df = df.copy()
    #     unique_ticker = df.tic.unique()
    #     if_first_time = True
    #     for tic in unique_ticker:
    #         if if_first_time:
    #             price_array = df[df.tic==tic][['close']].values
    #             tech_array = df[df.tic==tic][tech_indicator_list].values
    #             #risk_array = df[df.tic==tic]['turbulence'].values
    #             if_first_time = False
    #         else:
    #             price_array = np.hstack([price_array, df[df.tic==tic][['close']].values])
    #             tech_array = np.hstack([tech_array, df[df.tic==tic][tech_indicator_list].values])
    #     print('Successfully transformed into array')
    #     return price_array, tech_array, None
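A minimal, hypothetical usage sketch for the Ricequant processor; __init__ reads kwargs["username"] and kwargs["password"] directly, so pass both (None works when an rqdatac license is already configured locally):

from meta.data_processors.ricequant import Ricequant  # adjust to the actual package layout

processor = Ricequant(
    data_source="ricequant",
    start_date="2021-01-04",
    end_date="2021-06-30",
    time_interval="1d",
    username=None,  # None + None -> ricequant.init() using the local license
    password=None,
)
processor.download_data(["000001.XSHE"], save_path="./data/ricequant.csv")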
finnlp/data_processors/tushare.py
ADDED
@@ -0,0 +1,318 @@
import copy
import os
import time
import warnings

warnings.filterwarnings("ignore")
from typing import List

import pandas as pd
from tqdm import tqdm
from matplotlib import pyplot as plt

import stockstats
import talib
from meta.data_processors._base import _Base

import tushare as ts


class Tushare(_Base):
    """
    key-value in kwargs
    ----------
    token : str
        get from https://waditu.com/ after registration
    adj: str
        Whether to use adjusted closing price. Default is None.
        If you want to use forward adjusted closing price (前复权), please use 'qfq'.
        If you want to use backward adjusted closing price (后复权), please use 'hfq'.
    """

    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
        assert "token" in kwargs.keys(), "Please input token!"
        self.token = kwargs["token"]
        if "adj" in kwargs.keys():
            self.adj = kwargs["adj"]
            print(f"Using {self.adj} method.")
        else:
            self.adj = None

    def get_data(self, id) -> pd.DataFrame:
        # df1 = ts.pro_bar(ts_code=id, start_date=self.start_date,end_date='20180101')
        # dfb=pd.concat([df, df1], ignore_index=True)
        # print(dfb.shape)
        return ts.pro_bar(
            ts_code=id,
            start_date=self.start_date,
            end_date=self.end_date,
            adj=self.adj,
        )

    def download_data(
        self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
    ):
        """
        `pd.DataFrame`
            7 columns: a tick symbol, time, open, high, low, close and volume
            for the specified stock ticker
        """
        assert self.time_interval == "1d", "Not supported currently"

        self.ticker_list = ticker_list
        ts.set_token(self.token)

        self.dataframe = pd.DataFrame()
        for i in tqdm(ticker_list, total=len(ticker_list)):
            # nonstandard_id = self.transfer_standard_ticker_to_nonstandard(i)
            # df_temp = self.get_data(nonstandard_id)
            df_temp = self.get_data(i)
            self.dataframe = self.dataframe.append(df_temp)
            # print("{} ok".format(i))
            time.sleep(0.25)

        self.dataframe.columns = [
            "tic",
            "time",
            "open",
            "high",
            "low",
            "close",
            "pre_close",
            "change",
            "pct_chg",
            "volume",
            "amount",
        ]
        self.dataframe.sort_values(by=["time", "tic"], inplace=True)
        self.dataframe.reset_index(drop=True, inplace=True)

        self.dataframe = self.dataframe[
            ["tic", "time", "open", "high", "low", "close", "volume"]
        ]
        # self.dataframe.loc[:, 'tic'] = pd.DataFrame((self.dataframe['tic'].tolist()))
        self.dataframe["time"] = pd.to_datetime(self.dataframe["time"], format="%Y%m%d")
        self.dataframe["day"] = self.dataframe["time"].dt.dayofweek
        self.dataframe["time"] = self.dataframe.time.apply(
            lambda x: x.strftime("%Y-%m-%d")
        )

        self.dataframe.dropna(inplace=True)
        self.dataframe.sort_values(by=["time", "tic"], inplace=True)
        self.dataframe.reset_index(drop=True, inplace=True)

        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

    def data_split(self, df, start, end, target_date_col="time"):
        """
        split the dataset into training or testing using time
        :param data: (df) pandas dataframe, start, end
        :return: (df) pandas dataframe
        """
        data = df[(df[target_date_col] >= start) & (df[target_date_col] < end)]
        data = data.sort_values([target_date_col, "tic"], ignore_index=True)
        data.index = data[target_date_col].factorize()[0]
        return data

    def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str:
        # "600000.XSHG" -> "600000.SH"
        # "000612.XSHE" -> "000612.SZ"
        n, alpha = ticker.split(".")
        assert alpha in ["XSHG", "XSHE"], "Wrong alpha"
        if alpha == "XSHG":
            nonstandard_ticker = n + ".SH"
        elif alpha == "XSHE":
            nonstandard_ticker = n + ".SZ"
        return nonstandard_ticker

    def save_data(self, path):
        if ".csv" in path:
            path = path.split("/")
            filename = path[-1]
            path = "/".join(path[:-1] + [""])
        else:
            if path[-1] == "/":
                filename = "dataset.csv"
            else:
                filename = "/dataset.csv"

        os.makedirs(path, exist_ok=True)
        self.dataframe.to_csv(path + filename, index=False)

    def load_data(self, path):
        assert ".csv" in path  # only support csv format now
        self.dataframe = pd.read_csv(path)
        columns = self.dataframe.columns
        assert (
            "tic" in columns and "time" in columns and "close" in columns
        )  # input file must have "tic","time" and "close" columns
163 |
+
class ReturnPlotter:
|
164 |
+
"""
|
165 |
+
An easy-to-use plotting tool to plot cumulative returns over time.
|
166 |
+
Baseline supports equal weighting(default) and any stocks you want to use for comparison.
|
167 |
+
"""
|
168 |
+
|
169 |
+
def __init__(self, df_account_value, df_trade, start_date, end_date):
|
170 |
+
self.start = start_date
|
171 |
+
self.end = end_date
|
172 |
+
self.trade = df_trade
|
173 |
+
self.df_account_value = df_account_value
|
174 |
+
|
175 |
+
def get_baseline(self, ticket):
|
176 |
+
df = ts.get_hist_data(ticket, start=self.start, end=self.end)
|
177 |
+
df.loc[:, "dt"] = df.index
|
178 |
+
df.index = range(len(df))
|
179 |
+
df.sort_values(axis=0, by="dt", ascending=True, inplace=True)
|
180 |
+
df["time"] = pd.to_datetime(df["dt"], format="%Y-%m-%d")
|
181 |
+
return df
|
182 |
+
|
183 |
+
def plot(self, baseline_ticket=None):
|
184 |
+
"""
|
185 |
+
Plot cumulative returns over time.
|
186 |
+
use baseline_ticket to specify stock you want to use for comparison
|
187 |
+
(default: equal weighted returns)
|
188 |
+
"""
|
189 |
+
baseline_label = "Equal-weight portfolio"
|
190 |
+
tic2label = {"399300": "CSI 300 Index", "000016": "SSE 50 Index"}
|
191 |
+
if baseline_ticket:
|
192 |
+
# 使用指定ticket作为baseline
|
193 |
+
baseline_df = self.get_baseline(baseline_ticket)
|
194 |
+
baseline_date_list = baseline_df.time.dt.strftime("%Y-%m-%d").tolist()
|
195 |
+
df_date_list = self.df_account_value.time.tolist()
|
196 |
+
df_account_value = self.df_account_value[
|
197 |
+
self.df_account_value.time.isin(baseline_date_list)
|
198 |
+
]
|
199 |
+
baseline_df = baseline_df[baseline_df.time.isin(df_date_list)]
|
200 |
+
baseline = baseline_df.close.tolist()
|
201 |
+
baseline_label = tic2label.get(baseline_ticket, baseline_ticket)
|
202 |
+
ours = df_account_value.account_value.tolist()
|
203 |
+
else:
|
204 |
+
# 均等权重
|
205 |
+
all_date = self.trade.time.unique().tolist()
|
206 |
+
baseline = []
|
207 |
+
for day in all_date:
|
208 |
+
day_close = self.trade[self.trade["time"] == day].close.tolist()
|
209 |
+
avg_close = sum(day_close) / len(day_close)
|
210 |
+
baseline.append(avg_close)
|
211 |
+
ours = self.df_account_value.account_value.tolist()
|
212 |
+
|
213 |
+
ours = self.pct(ours)
|
214 |
+
baseline = self.pct(baseline)
|
215 |
+
|
216 |
+
days_per_tick = (
|
217 |
+
60 # you should scale this variable accroding to the total trading days
|
218 |
+
)
|
219 |
+
time = list(range(len(ours)))
|
220 |
+
datetimes = self.df_account_value.time.tolist()
|
221 |
+
ticks = [tick for t, tick in zip(time, datetimes) if t % days_per_tick == 0]
|
222 |
+
plt.title("Cumulative Returns")
|
223 |
+
plt.plot(time, ours, label="DDPG Agent", color="green")
|
224 |
+
plt.plot(time, baseline, label=baseline_label, color="grey")
|
225 |
+
plt.xticks([i * days_per_tick for i in range(len(ticks))], ticks, fontsize=7)
|
226 |
+
|
227 |
+
plt.xlabel("Date")
|
228 |
+
plt.ylabel("Cumulative Return")
|
229 |
+
|
230 |
+
plt.legend()
|
231 |
+
plt.show()
|
232 |
+
plt.savefig(f"plot_{baseline_ticket}.png")
|
233 |
+
|
234 |
+
def plot_all(self):
|
235 |
+
baseline_label = "Equal-weight portfolio"
|
236 |
+
tic2label = {"399300": "CSI 300 Index", "000016": "SSE 50 Index"}
|
237 |
+
|
238 |
+
# time lists
|
239 |
+
# algorithm time list
|
240 |
+
df_date_list = self.df_account_value.time.tolist()
|
241 |
+
|
242 |
+
# 399300 time list
|
243 |
+
csi300_df = self.get_baseline("399300")
|
244 |
+
csi300_date_list = csi300_df.time.dt.strftime("%Y-%m-%d").tolist()
|
245 |
+
|
246 |
+
# 000016 time list
|
247 |
+
sh50_df = self.get_baseline("000016")
|
248 |
+
sh50_date_list = sh50_df.time.dt.strftime("%Y-%m-%d").tolist()
|
249 |
+
|
250 |
+
# find intersection
|
251 |
+
all_date = sorted(
|
252 |
+
list(set(df_date_list) & set(csi300_date_list) & set(sh50_date_list))
|
253 |
+
)
|
254 |
+
|
255 |
+
# filter data
|
256 |
+
csi300_df = csi300_df[csi300_df.time.isin(all_date)]
|
257 |
+
baseline_300 = csi300_df.close.tolist()
|
258 |
+
baseline_label_300 = tic2label["399300"]
|
259 |
+
|
260 |
+
sh50_df = sh50_df[sh50_df.time.isin(all_date)]
|
261 |
+
baseline_50 = sh50_df.close.tolist()
|
262 |
+
baseline_label_50 = tic2label["000016"]
|
263 |
+
|
264 |
+
# 均等权重
|
265 |
+
baseline_equal_weight = []
|
266 |
+
for day in all_date:
|
267 |
+
day_close = self.trade[self.trade["time"] == day].close.tolist()
|
268 |
+
avg_close = sum(day_close) / len(day_close)
|
269 |
+
baseline_equal_weight.append(avg_close)
|
270 |
+
|
271 |
+
df_account_value = self.df_account_value[
|
272 |
+
self.df_account_value.time.isin(all_date)
|
273 |
+
]
|
274 |
+
ours = df_account_value.account_value.tolist()
|
275 |
+
|
276 |
+
ours = self.pct(ours)
|
277 |
+
baseline_300 = self.pct(baseline_300)
|
278 |
+
baseline_50 = self.pct(baseline_50)
|
279 |
+
baseline_equal_weight = self.pct(baseline_equal_weight)
|
280 |
+
|
281 |
+
days_per_tick = (
|
282 |
+
60 # you should scale this variable accroding to the total trading days
|
283 |
+
)
|
284 |
+
time = list(range(len(ours)))
|
285 |
+
datetimes = self.df_account_value.time.tolist()
|
286 |
+
ticks = [tick for t, tick in zip(time, datetimes) if t % days_per_tick == 0]
|
287 |
+
plt.title("Cumulative Returns")
|
288 |
+
plt.plot(time, ours, label="DDPG Agent", color="darkorange")
|
289 |
+
plt.plot(
|
290 |
+
time,
|
291 |
+
baseline_equal_weight,
|
292 |
+
label=baseline_label,
|
293 |
+
color="cornflowerblue",
|
294 |
+
) # equal weight
|
295 |
+
plt.plot(
|
296 |
+
time, baseline_300, label=baseline_label_300, color="lightgreen"
|
297 |
+
) # 399300
|
298 |
+
plt.plot(time, baseline_50, label=baseline_label_50, color="silver") # 000016
|
299 |
+
plt.xlabel("Date")
|
300 |
+
plt.ylabel("Cumulative Return")
|
301 |
+
|
302 |
+
plt.xticks([i * days_per_tick for i in range(len(ticks))], ticks, fontsize=7)
|
303 |
+
plt.legend()
|
304 |
+
plt.show()
|
305 |
+
plt.savefig("./plot_all.png")
|
306 |
+
|
307 |
+
def pct(self, l):
|
308 |
+
"""Get percentage"""
|
309 |
+
base = l[0]
|
310 |
+
return [x / base for x in l]
|
311 |
+
|
312 |
+
def get_return(self, df, value_col_name="account_value"):
|
313 |
+
df = copy.deepcopy(df)
|
314 |
+
df["daily_return"] = df[value_col_name].pct_change(1)
|
315 |
+
df["time"] = pd.to_datetime(df["time"], format="%Y-%m-%d")
|
316 |
+
df.set_index("time", inplace=True, drop=True)
|
317 |
+
df.index = df.index.tz_localize("UTC")
|
318 |
+
return pd.Series(df["daily_return"], index=df.index)
|
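A minimal usage sketch for the Tushare processor above, assuming a valid token from https://waditu.com/; the token string and tickers below are placeholders, and the date formats follow what the methods themselves expect (YYYYMMDD for ts.pro_bar, YYYY-MM-DD for data_split):

# Hypothetical usage of the Tushare processor defined above.
# "YOUR_TUSHARE_TOKEN" and the ticker list are placeholders.
from finnlp.data_processors.tushare import Tushare

processor = Tushare(
    data_source="tushare",
    start_date="20210101",
    end_date="20211231",
    time_interval="1d",          # the only interval download_data accepts
    token="YOUR_TUSHARE_TOKEN",  # placeholder; register at https://waditu.com/
    adj="qfq",                   # forward-adjusted close, per the docstring
)
processor.download_data(["600000.SH", "000612.SZ"], save_path="./data/dataset.csv")
train = processor.data_split(processor.dataframe, "2021-01-01", "2021-10-01")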
finnlp/data_processors/wrds.py
ADDED
@@ -0,0 +1,330 @@
import datetime
from typing import List

import numpy as np
import pandas as pd
import pytz
import wrds

try:
    import exchange_calendars as tc
except ImportError:
    print(
        "Cannot import exchange_calendars.",
        "If you are using python>=3.7, please install it.",
    )
    import trading_calendars as tc

    print("Use trading_calendars instead for wrds processor.")
# from basic_processor import _Base
from finnlp.data_processors._base import _Base

pd.options.mode.chained_assignment = None


class Wrds(_Base):
    # def __init__(self, if_offline=False):
    #     if not if_offline:
    #         self.db = wrds.Connection()
    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
        if "if_offline" in kwargs.keys() and not kwargs["if_offline"]:
            self.db = wrds.Connection()

    def download_data(
        self,
        ticker_list: List[str],
        if_save_tempfile=False,
        filter_shares=0,
        save_path: str = "./data/dataset.csv",
    ):
        dates = self.get_trading_days(self.start_date, self.end_date)
        print("Trading days: ")
        print(dates)
        first_time = True
        empty = True
        stock_set = tuple(ticker_list)
        for i in dates:
            x = self.data_fetch_wrds(i, stock_set, filter_shares, self.time_interval)

            if not x[1]:
                empty = False
                dataset = x[0]
                dataset = self.preprocess_to_ohlcv(
                    dataset, time_interval=(str(self.time_interval) + "S")
                )
                print("Data for date: " + i + " finished")
                if first_time:
                    temp = dataset
                    first_time = False
                else:
                    temp = pd.concat([temp, dataset])
                if if_save_tempfile:
                    temp.to_csv("./temp.csv")
        if empty:
            raise ValueError("Empty Data under input parameters!")
        result = temp
        result = result.sort_values(by=["time", "tic"])
        result = result.reset_index(drop=True)
        self.dataframe = result

        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

    def preprocess_to_ohlcv(self, df, time_interval="60S"):
        df = df[["date", "time_m", "sym_root", "size", "price"]]
        tic_list = np.unique(df["sym_root"].values)
        final_df = None
        first_time = True
        for i in range(len(tic_list)):
            tic = tic_list[i]
            time_list = []
            temp_df = df[df["sym_root"] == tic]
            for j in range(temp_df.shape[0]):
                date = temp_df["date"].iloc[j]
                time_m = temp_df["time_m"].iloc[j]
                time = str(date) + " " + str(time_m)
                try:
                    time = datetime.datetime.strptime(time, "%Y-%m-%d %H:%M:%S.%f")
                except ValueError:
                    time = datetime.datetime.strptime(time, "%Y-%m-%d %H:%M:%S")
                time_list.append(time)
            temp_df["time"] = time_list
            temp_df = temp_df.set_index("time")
            data_ohlc = temp_df["price"].resample(time_interval).ohlc()
            data_v = temp_df["size"].resample(time_interval).agg({"size": "sum"})
            volume = data_v["size"].values
            data_ohlc["volume"] = volume
            data_ohlc["tic"] = tic
            if first_time:
                final_df = data_ohlc.reset_index()
                first_time = False
            else:
                # DataFrame.append was removed in pandas 2.0; use pd.concat instead
                final_df = pd.concat(
                    [final_df, data_ohlc.reset_index()], ignore_index=True
                )
        return final_df

    def clean_data(self):
        df = self.dataframe[["time", "open", "high", "low", "close", "volume", "tic"]]
        # remove 16:00 data
        tic_list = np.unique(df["tic"].values)
        ary = df.values
        rows_1600 = []
        for i in range(ary.shape[0]):
            row = ary[i]
            time = row[0]
            if str(time)[-8:] == "16:00:00":
                rows_1600.append(i)

        df = df.drop(rows_1600)
        df = df.sort_values(by=["tic", "time"])

        # check missing rows
        tic_dic = {tic: [0, 0] for tic in tic_list}
        ary = df.values
        for i in range(ary.shape[0]):
            row = ary[i]
            volume = row[5]
            tic = row[6]
            if volume != 0:
                tic_dic[tic][0] += 1
            tic_dic[tic][1] += 1
        constant = np.unique(df["time"].values).shape[0]
        nan_tics = [tic for tic, value in tic_dic.items() if value[1] != constant]
        # fill missing rows
        normal_time = np.unique(df["time"].values)

        df2 = df.copy()
        for tic in nan_tics:
            tic_time = df[df["tic"] == tic]["time"].values
            missing_time = [i for i in normal_time if i not in tic_time]
            for time in missing_time:
                temp_df = pd.DataFrame(
                    [[time, np.nan, np.nan, np.nan, np.nan, 0, tic]],
                    columns=[
                        "time",
                        "open",
                        "high",
                        "low",
                        "close",
                        "volume",
                        "tic",
                    ],
                )
                # DataFrame.append was removed in pandas 2.0; use pd.concat instead
                df2 = pd.concat([df2, temp_df], ignore_index=True)

        # fill nan data
        df = df2.sort_values(by=["tic", "time"])
        for i in range(df.shape[0]):
            if float(df.iloc[i]["volume"]) == 0:
                previous_close = df.iloc[i - 1]["close"]
                if str(previous_close) == "nan":
                    raise ValueError("Error nan price")
                df.iloc[i, 1] = previous_close
                df.iloc[i, 2] = previous_close
                df.iloc[i, 3] = previous_close
                df.iloc[i, 4] = previous_close
        # check if nan
        ary = df[["open", "high", "low", "close", "volume"]].values
        assert not np.isnan(np.min(ary))
        # final preprocess
        df = df[["time", "open", "high", "low", "close", "volume", "tic"]]
        df = df.reset_index(drop=True)
        print("Data clean finished")
        self.dataframe = df

    def get_trading_days(self, start, end):
        nyse = tc.get_calendar("NYSE")
        df = nyse.sessions_in_range(
            pd.Timestamp(start, tz=pytz.UTC), pd.Timestamp(end, tz=pytz.UTC)
        )
        return [str(day)[:10] for day in df]

    def data_fetch_wrds(
        self,
        date="2021-05-01",
        stock_set=("AAPL",),
        filter_shares=0,
        time_interval=60,
    ):
        # start_date, end_date should be in the same year
        current_date = datetime.datetime.strptime(date, "%Y-%m-%d")
        lib = "taqm_" + str(current_date.year)  # taqm_2021
        table = "ctm_" + current_date.strftime("%Y%m%d")  # ctm_20210501

        parm = {"syms": stock_set, "num_shares": filter_shares}
        try:
            data = self.db.raw_sql(
                "select * from "
                + lib
                + "."
                + table
                + " where sym_root in %(syms)s and time_m between '9:30:00' and '16:00:00' and size > %(num_shares)s and sym_suffix is null",
                params=parm,
            )
            if_empty = False
            return data, if_empty
        except Exception:
            print("Data for date: " + date + " error")
            if_empty = True
            return None, if_empty

    # def add_technical_indicator(self, df, tech_indicator_list = [
    #             'macd', 'boll_ub', 'boll_lb', 'rsi_30', 'dx_30',
    #             'close_30_sma', 'close_60_sma']):
    #     df = df.rename(columns={'time': 'date'})
    #     df = df.copy()
    #     df = df.sort_values(by=['tic', 'date'])
    #     stock = Sdf.retype(df.copy())
    #     unique_ticker = stock.tic.unique()
    #     tech_indicator_list = tech_indicator_list
    #
    #     for indicator in tech_indicator_list:
    #         indicator_df = pd.DataFrame()
    #         for i in range(len(unique_ticker)):
    #             # print(unique_ticker[i], i)
    #             temp_indicator = stock[stock.tic == unique_ticker[i]][indicator]
    #             temp_indicator = pd.DataFrame(temp_indicator)
    #             temp_indicator['tic'] = unique_ticker[i]
    #             # print(len(df[df.tic == unique_ticker[i]]['date'].to_list()))
    #             temp_indicator['date'] = df[df.tic == unique_ticker[i]]['date'].to_list()
    #             indicator_df = indicator_df.append(
    #                 temp_indicator, ignore_index=True
    #             )
    #         df = df.merge(indicator_df[['tic', 'date', indicator]], on=['tic', 'date'], how='left')
    #     df = df.sort_values(by=['date', 'tic'])
    #     print('Successfully add technical indicators')
    #     return df

    # def calculate_turbulence(self, data, time_period=252):
    #     # can add other market assets
    #     df = data.copy()
    #     df_price_pivot = df.pivot(index="date", columns="tic", values="close")
    #     # use returns to calculate turbulence
    #     df_price_pivot = df_price_pivot.pct_change()
    #
    #     unique_date = df.date.unique()
    #     # start after a fixed time period
    #     start = time_period
    #     turbulence_index = [0] * start
    #     # turbulence_index = [0]
    #     count = 0
    #     for i in range(start, len(unique_date)):
    #         current_price = df_price_pivot[df_price_pivot.index == unique_date[i]]
    #         # use one year rolling window to calculate covariance
    #         hist_price = df_price_pivot[
    #             (df_price_pivot.index < unique_date[i])
    #             & (df_price_pivot.index >= unique_date[i - time_period])
    #         ]
    #         # Drop tickers which have more missing values than the "oldest" ticker
    #         filtered_hist_price = hist_price.iloc[hist_price.isna().sum().min():].dropna(axis=1)
    #
    #         cov_temp = filtered_hist_price.cov()
    #         current_temp = current_price[[x for x in filtered_hist_price]] - np.mean(filtered_hist_price, axis=0)
    #         temp = current_temp.values.dot(np.linalg.pinv(cov_temp)).dot(
    #             current_temp.values.T
    #         )
    #         if temp > 0:
    #             count += 1
    #             if count > 2:
    #                 turbulence_temp = temp[0][0]
    #             else:
    #                 # avoid large outliers because the calculation has just begun
    #                 turbulence_temp = 0
    #         else:
    #             turbulence_temp = 0
    #         turbulence_index.append(turbulence_temp)
    #
    #     turbulence_index = pd.DataFrame(
    #         {"date": df_price_pivot.index, "turbulence": turbulence_index}
    #     )
    #     return turbulence_index
    #
    # def add_turbulence(self, data, time_period=252):
    #     """
    #     add turbulence index from a precalculated dataframe
    #     :param data: (df) pandas dataframe
    #     :return: (df) pandas dataframe
    #     """
    #     df = data.copy()
    #     turbulence_index = self.calculate_turbulence(df, time_period=time_period)
    #     df = df.merge(turbulence_index, on="date")
    #     df = df.sort_values(["date", "tic"]).reset_index(drop=True)
    #     return df

    # def add_vix(self, data):
    #     vix_df = self.download_data(['vix'], self.start, self.end_date, self.time_interval)
    #     cleaned_vix = self.clean_data(vix_df)
    #     vix = cleaned_vix[['date', 'close']]
    #
    #     df = data.copy()
    #     df = df.merge(vix, on="date")
    #     df = df.sort_values(["date", "tic"]).reset_index(drop=True)
    #
    #     return df

    # def df_to_array(self, df, tech_indicator_list):
    #     unique_ticker = df.tic.unique()
    #     print(unique_ticker)
    #     if_first_time = True
    #     for tic in unique_ticker:
    #         if if_first_time:
    #             price_array = df[df.tic==tic][['close']].values
    #             # price_ary = df[df.tic==tic]['close'].values
    #             tech_array = df[df.tic==tic][tech_indicator_list].values
    #             risk_array = df[df.tic==tic]['turbulence'].values
    #             if_first_time = False
    #         else:
    #             price_array = np.hstack([price_array, df[df.tic==tic][['close']].values])
    #             tech_array = np.hstack([tech_array, df[df.tic==tic][tech_indicator_list].values])
    #     print('Successfully transformed into array')
    #     return price_array, tech_array, risk_array
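A minimal usage sketch for the Wrds processor, assuming a WRDS account with access to the TAQ millisecond trade libraries (taqm_YYYY) that data_fetch_wrds queries; wrds.Connection() will prompt for credentials unless a .pgpass file is configured. Tickers and dates below are illustrative:

# Hypothetical usage of the Wrds processor defined above.
from finnlp.data_processors.wrds import Wrds

processor = Wrds(
    data_source="wrds",
    start_date="2021-05-03",
    end_date="2021-05-07",
    time_interval=60,    # seconds per OHLCV bar, per preprocess_to_ohlcv
    if_offline=False,    # open the database connection in __init__
)
processor.download_data(["AAPL", "MSFT"], filter_shares=0)
processor.clean_data()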
finnlp/data_processors/yahoofinance.py
ADDED
@@ -0,0 +1,190 @@
from typing import List

import numpy as np
import pandas as pd
import pytz
import yfinance as yf

try:
    import exchange_calendars as tc
except ImportError:
    print(
        "Cannot import exchange_calendars.",
        "If you are using python>=3.7, please install it.",
    )
    import trading_calendars as tc

    print("Use trading_calendars instead for yahoofinance processor.")

from finnlp.utils.config import (
    BINANCE_BASE_URL,
    TIME_ZONE_BERLIN,
    TIME_ZONE_JAKARTA,
    TIME_ZONE_PARIS,
    TIME_ZONE_SELFDEFINED,
    TIME_ZONE_SHANGHAI,
    TIME_ZONE_USEASTERN,
    USE_TIME_ZONE_SELFDEFINED,
)
from finnlp.data_processors._base import _Base, calc_time_zone


class Yahoofinance(_Base):
    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)

    def download_data(
        self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
    ):
        self.time_zone = calc_time_zone(
            ticker_list, TIME_ZONE_SELFDEFINED, USE_TIME_ZONE_SELFDEFINED
        )
        self.dataframe = pd.DataFrame()
        for tic in ticker_list:
            temp_df = yf.download(
                tic,
                start=self.start_date,
                end=self.end_date,
                interval=self.time_interval,
            )
            temp_df["tic"] = tic
            self.dataframe = pd.concat([self.dataframe, temp_df], axis=0, join="outer")
        self.dataframe.reset_index(inplace=True)
        try:
            self.dataframe.columns = [
                "date",
                "open",
                "high",
                "low",
                "close",
                "adjusted_close",
                "volume",
                "tic",
            ]
        except ValueError:
            # column count mismatch: the downloaded features are not supported
            print("the features are not supported currently")
        self.dataframe["day"] = self.dataframe["date"].dt.dayofweek
        print(self.dataframe)
        self.dataframe["date"] = self.dataframe.date.apply(
            lambda x: x.strftime("%Y-%m-%d")
        )
        self.dataframe.dropna(inplace=True)
        self.dataframe.reset_index(drop=True, inplace=True)
        print("Shape of DataFrame: ", self.dataframe.shape)
        self.dataframe.sort_values(by=["date", "tic"], inplace=True)
        self.dataframe.reset_index(drop=True, inplace=True)

        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

    def clean_data(self):
        df = self.dataframe.copy()
        df = df.rename(columns={"date": "time"})
        time_interval = self.time_interval
        tic_list = np.unique(df.tic.values)
        trading_days = self.get_trading_days(start=self.start_date, end=self.end_date)
        if time_interval == "1D":
            times = trading_days
        elif time_interval == "1Min":
            times = []
            for day in trading_days:
                current_time = pd.Timestamp(day + " 09:30:00").tz_localize(
                    self.time_zone
                )
                for _ in range(390):
                    times.append(current_time)
                    current_time += pd.Timedelta(minutes=1)
        else:
            raise ValueError(
                "Data clean at given time interval is not supported for YahooFinance data."
            )
        new_df = pd.DataFrame()
        for tic in tic_list:
            print(("Clean data for ") + tic)
            tmp_df = pd.DataFrame(
                columns=[
                    "open",
                    "high",
                    "low",
                    "close",
                    "adjusted_close",
                    "volume",
                ],
                index=times,
            )
            # get data for the current ticker
            tic_df = df[df.tic == tic]
            # fill the empty DataFrame using the original data
            for i in range(tic_df.shape[0]):
                tmp_df.loc[tic_df.iloc[i]["time"]] = tic_df.iloc[i][
                    [
                        "open",
                        "high",
                        "low",
                        "close",
                        "adjusted_close",
                        "volume",
                    ]
                ]

            # if close on the start date is NaN, fill data with the first valid close
            # and set volume to 0.
            if str(tmp_df.iloc[0]["close"]) == "nan":
                print("NaN data on start date, fill using first valid data.")
                for i in range(tmp_df.shape[0]):
                    if str(tmp_df.iloc[i]["close"]) != "nan":
                        first_valid_close = tmp_df.iloc[i]["close"]
                        first_valid_adjclose = tmp_df.iloc[i]["adjusted_close"]
                        break  # stop at the first valid row, as the message says

                tmp_df.iloc[0] = [
                    first_valid_close,
                    first_valid_close,
                    first_valid_close,
                    first_valid_close,
                    first_valid_adjclose,
                    0.0,
                ]

            # fill NaN data with the previous close and set volume to 0.
            for i in range(tmp_df.shape[0]):
                if str(tmp_df.iloc[i]["close"]) == "nan":
                    previous_close = tmp_df.iloc[i - 1]["close"]
                    previous_adjusted_close = tmp_df.iloc[i - 1]["adjusted_close"]
                    if str(previous_close) == "nan":
                        raise ValueError
                    tmp_df.iloc[i] = [
                        previous_close,
                        previous_close,
                        previous_close,
                        previous_close,
                        previous_adjusted_close,
                        0.0,
                    ]

            # merge single-ticker data into the new DataFrame
            tmp_df = tmp_df.astype(float)
            tmp_df["tic"] = tic
            # DataFrame.append was removed in pandas 2.0; use pd.concat instead
            new_df = pd.concat([new_df, tmp_df])

            print(("Data clean for ") + tic + (" is finished."))

        # reset index and rename columns
        new_df = new_df.reset_index()
        new_df = new_df.rename(columns={"index": "time"})
        print("Data clean all finished!")
        self.dataframe = new_df

    def get_trading_days(self, start, end):
        nyse = tc.get_calendar("NYSE")
        df = nyse.sessions_in_range(pd.Timestamp(start), pd.Timestamp(end))
        return [str(day)[:10] for day in df]
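A minimal usage sketch for the Yahoofinance processor. Note the interval string is passed straight to yf.download, while clean_data checks for "1D" or "1Min"; whether the base class normalizes between yfinance's "1d" and this "1D" is an assumption here, so adjust the interval to whichever layer you are exercising:

# Hypothetical usage of the Yahoofinance processor defined above.
from finnlp.data_processors.yahoofinance import Yahoofinance

processor = Yahoofinance(
    data_source="yahoofinance",
    start_date="2021-01-01",
    end_date="2021-12-31",
    time_interval="1D",  # clean_data expects "1D"; yfinance itself uses "1d"
)
processor.download_data(["AAPL", "MSFT"], save_path="./data/dataset.csv")
processor.clean_data()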