# Spaces: Runtime error
# (The line above appears to be page-status text captured with this file; it is not part of the program.)
import copy | |
import os | |
import time | |
import warnings | |
warnings.filterwarnings("ignore") | |
from typing import List | |
import pandas as pd | |
from tqdm import tqdm | |
import stockstats | |
import talib | |
from meta.data_processors._base import _Base | |
import akshare as ak # pip install akshare | |
class Akshare(_Base): | |
def __init__(
    self,
    data_source: str,
    start_date: str,
    end_date: str,
    time_interval: str,
    **kwargs,
):
    """
    Set up the processor and normalise the requested date range.

    Both dates are passed through ``transfer_date`` before being handed to
    the base-class initialiser, so the base class receives them without
    "-", "." or "/" separators.

    :param data_source: forwarded unchanged to the base class
    :param start_date: first date of the range, e.g. "2020-01-01"
    :param end_date: last date of the range, same formats as ``start_date``
    :param time_interval: forwarded unchanged to the base class
    :param kwargs: forwarded to the base class; two optional keys are also
        read here: ``adj`` (stored as ``self.adj``, default ``""``) and
        ``period`` (stored as ``self.period``, default ``"daily"``)
    """
    start_date = self.transfer_date(start_date)
    end_date = self.transfer_date(end_date)
    super().__init__(data_source, start_date, end_date, time_interval, **kwargs)

    # The message is printed only when the caller explicitly chose an
    # adjustment method, so the default case stays silent.
    if "adj" in kwargs:
        self.adj = kwargs["adj"]
        print(f"Using {self.adj} method.")
    else:
        self.adj = ""

    self.period = kwargs.get("period", "daily")
def get_data(self, id) -> pd.DataFrame:
    """
    Fetch the price history of a single symbol via ``ak.stock_zh_a_hist``.

    :param id: symbol passed to AkShare as ``symbol``
    :return: whatever ``ak.stock_zh_a_hist`` returns for the query
    """
    # The sampling period comes from ``self.time_interval``; the date range
    # and adjustment mode come from the instance attributes of the same name.
    query = {
        "symbol": id,
        "period": self.time_interval,
        "start_date": self.start_date,
        "end_date": self.end_date,
        "adjust": self.adj,
    }
    return ak.stock_zh_a_hist(**query)
def download_data(
    self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
):
    """
    Download price history for every ticker, normalise it and save it.

    Each ticker is converted with ``transfer_standard_ticker_to_nonstandard``,
    fetched with ``get_data`` and tagged with its original symbol in a
    ``tic`` column. The combined table is stored in ``self.dataframe`` with
    the columns ``tic, time, open, high, low, close, volume, day``
    (``time`` as a ``%Y-%m-%d`` string, ``day`` as the pandas day-of-week
    number), sorted by ``time`` then ``tic``, and written out through
    ``self.save_data``.

    :param ticker_list: ticker symbols to download, e.g. "600000.XSHG"
    :param save_path: destination handed to ``self.save_data``
    :return: None; the result is left in ``self.dataframe``
    :raises AssertionError: if ``self.time_interval`` is not "daily",
        "weekly" or "monthly"
    :raises ValueError: if ``ticker_list`` is empty, or the fetched columns
        cannot be mapped onto the expected layout
    """
    # Kept as an assert so callers that rely on AssertionError still work.
    assert self.time_interval in [
        "daily",
        "weekly",
        "monthly",
    ], "Not supported currently"
    if len(ticker_list) == 0:
        # Previously this surfaced later as an obscure length-mismatch error.
        raise ValueError("ticker_list is empty; nothing to download")

    self.ticker_list = ticker_list

    # Collect the per-ticker frames and concatenate once; concatenating
    # inside the loop re-copies everything downloaded so far on each pass.
    frames = []
    for ticker in tqdm(ticker_list, total=len(ticker_list)):
        nonstandard_id = self.transfer_standard_ticker_to_nonstandard(ticker)
        frame = self.get_data(nonstandard_id)
        frame["tic"] = ticker
        frames.append(frame)
        time.sleep(0.25)  # presumably throttles requests to the data source
    self.dataframe = pd.concat(frames)

    wanted = ["tic", "time", "open", "high", "low", "close", "volume"]

    # Rename by header name so an extra or reordered column cannot shift
    # the labels.
    # NOTE(review): header names follow the AkShare documentation for
    # stock_zh_a_hist -- confirm against the installed version.
    header_map = {
        "日期": "time",
        "开盘": "open",
        "收盘": "close",
        "最高": "high",
        "最低": "low",
        "成交量": "volume",
    }
    renamed = self.dataframe.rename(columns=header_map)
    if set(wanted).issubset(renamed.columns):
        self.dataframe = renamed
    else:
        # Unknown headers: fall back to the original positional layout
        # (raises ValueError if the column count does not match).
        self.dataframe.columns = [
            "time",
            "open",
            "close",
            "high",
            "low",
            "volume",
            "amount",
            "amplitude",
            "pct_chg",
            "change",
            "turnover",
            "tic",
        ]

    # Copy so the column assignments below act on an independent frame.
    self.dataframe = self.dataframe[wanted].copy()

    parsed_time = pd.to_datetime(self.dataframe["time"], format="%Y-%m-%d")
    self.dataframe["day"] = parsed_time.dt.dayofweek
    self.dataframe["time"] = parsed_time.dt.strftime("%Y-%m-%d")

    self.dataframe.dropna(inplace=True)
    self.dataframe.sort_values(by=["time", "tic"], inplace=True)
    self.dataframe.reset_index(drop=True, inplace=True)

    self.save_data(save_path)

    print(
        f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
    )
def data_split(self, df, start, end, target_date_col="time"):
    """
    Cut a date window out of ``df``, e.g. to build a train or test set.

    :param df: (df) pandas dataframe holding ``target_date_col`` and "tic"
    :param start: lower bound of the window (inclusive)
    :param end: upper bound of the window (exclusive)
    :param target_date_col: name of the column compared with start/end
    :return: (df) the selected rows sorted by date then "tic"; rows that
        share a date share the same integer index value
    """
    dates = df[target_date_col]
    in_window = (dates >= start) & (dates < end)

    subset = df.loc[in_window].sort_values(
        [target_date_col, "tic"], ignore_index=True
    )
    # One integer label per distinct date, numbered in order of appearance.
    date_codes, _ = subset[target_date_col].factorize()
    subset.index = date_codes
    return subset
def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str:
    """
    Strip the exchange suffix from a ticker, keeping only the code.

    :param ticker: symbol such as "600000.XSHG" or "000612.SZ"; a symbol
        without a "." is returned unchanged
    :return: the part of ``ticker`` before the first "."
    """
    # "600000.XSHG" -> "600000"
    # "000612.XSHE" -> "000612"
    # "600000.SH" -> "600000"
    # "000612.SZ" -> "000612"
    # "600000" -> "600000" (previously fell through and returned None)
    code, _, _ = ticker.partition(".")
    return code
def transfer_date(self, time: str) -> str:
    """
    Remove the separator from a date string, e.g. "2020-01-01" -> "20200101".

    Only one kind of separator is removed: the first of "-", "." and "/"
    (checked in that order) that occurs in the string. A string containing
    none of them is returned unchanged.

    :param time: date string
    :return: the date string without that separator
    """
    for separator in ("-", ".", "/"):
        if separator in time:
            return time.replace(separator, "")
    return time