import datetime
from typing import List

import jqdatasdk as jq
import numpy as np

from meta.data_processors._base import _Base


class Joinquant(_Base):
    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
        # Authenticate with JoinQuant only when both credentials are supplied.
        if "username" in kwargs and "password" in kwargs:
            jq.auth(kwargs["username"], kwargs["password"])

    def download_data(
        self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
    ):
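        # JoinQuant supports these bar units: '1m', '5m', '15m', '30m', '60m',
        # '120m', '1d', '1w', '1M'. '1w' denotes one week; '1M' denotes one month.
        # Note: count is the number of trading days in the range, which matches
        # the number of bars per ticker only for daily ('1d') bars; for other
        # units the bars returned may not cover the full date range.
        count = len(self.get_trading_days(self.start_date, self.end_date))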
        df = jq.get_bars(
            security=ticker_list,
            count=count,
            unit=self.time_interval,
            fields=["date", "open", "high", "low", "close", "volume"],
            end_dt=self.end_date,
        )
        df = df.reset_index().rename(columns={"level_0": "tic"})
        self.dataframe = df

        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

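    @staticmethod
    def preprocess(df, stock_list):
        """Stack each stock's block of rows side by side.

        Assumes ``df`` holds the same number of consecutive rows for every
        stock, ordered stock by stock; column 0 is skipped.
        """
        n = len(stock_list)
        N = df.shape[0]
        assert N % n == 0, "row count must be a multiple of the number of stocks"
        d = N // n  # rows per stock
        # Horizontally concatenate the per-stock blocks into shape (d, n * k).
        blocks = [df.iloc[j * d : (j + 1) * d, 1:].values for j in range(n)]
        return np.hstack(blocks)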

    def get_trading_days(self, start_day: str, end_day: str) -> List[str]:
        """Return the trading days from start_day to end_day as strings.

        Both arguments are date strings. The output is a list of
        "YYYY-MM-DD" strings, e.g. ['2021-09-01', '2021-09-02'].
        """
        dates = jq.get_trade_days(start_day, end_day)
        return [datetime.date.strftime(d, "%Y-%m-%d") for d in dates]
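

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module.
# The preprocess method above places each stock's block of rows side by side.
# The hypothetical helper below (its name and signature are assumptions made
# for this example) shows what is believed to be the same layout using only
# NumPy reshapes, which can make the resulting column order easier to see.
# It assumes the rows are ordered stock by stock, with an equal number of
# rows per stock, and that any identifier column has already been dropped.
# ---------------------------------------------------------------------------
import numpy as np  # already imported above; repeated so the sketch stands alone


def _stack_by_stock_sketch(values, n_stocks):
    """Reshape (n_stocks * d, k) rows into a (d, n_stocks * k) array."""
    values = np.asarray(values)
    total_rows, n_fields = values.shape
    if n_stocks <= 0 or total_rows % n_stocks != 0:
        raise ValueError("row count must be a positive multiple of n_stocks")
    d = total_rows // n_stocks  # rows per stock
    # (n_stocks, d, k) -> (d, n_stocks, k) -> (d, n_stocks * k)
    return (
        values.reshape(n_stocks, d, n_fields)
        .transpose(1, 0, 2)
        .reshape(d, n_stocks * n_fields)
    )


# Usage: two stocks, two rows each, two fields per row.
# _stack_by_stock_sketch([[1, 2], [3, 4], [5, 6], [7, 8]], 2)
# places stock 0's fields in columns 0-1 and stock 1's fields in columns 2-3.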