Krisseck committed on
Commit 23b39f0 · 1 Parent(s): 4b2fda1

Initial commit

.gitignore CHANGED
@@ -5,6 +5,7 @@ __pycache__/
  .ipynb_checkpoints
  *ipynb
  .vscode/
+ .idea

  eval-queue/
  eval-results/
data-processing/.gitignore ADDED
@@ -0,0 +1,3 @@
+ binance-btcusdt-futures-2020-2021
+ venv
+ test.*
data-processing/convert.js ADDED
@@ -0,0 +1,57 @@
+ import fs from 'node:fs/promises';
+ import * as zlib from "zlib";
+ import readline from 'node:readline';
+
+ // One second before 2020-07-01 00:00:00 UTC, in milliseconds
+ let currentTime = 1593561599000;
+
+ await fs.appendFile("test.csv", "timestamp,bid\n");
+
+ // Stream each gzipped daily file and keep the first bid quote of every second
+ // (fs.glob is available in Node 22+)
+ for await (const dayFile of fs.glob("./binance-btcusdt-futures-2020-2021/*.csv.gz")) {
+   console.log("Reading", dayFile, "...");
+
+   let outputRows = [];
+
+   const fd = await fs.open(dayFile);
+
+   let lineReader = readline.createInterface({
+     input: fd.createReadStream().pipe(zlib.createGunzip())
+   });
+
+   let n = 0;
+   for await (const line of lineReader) {
+     // Skip the header row (n === 0); column 1 holds the millisecond
+     // timestamp and column 3 the bid price
+     if (n > 0) {
+       let lineParts = line.split(',');
+       let timestamp = parseInt(lineParts[1]);
+       if (timestamp >= currentTime + 1000) {
+         currentTime = Math.floor(timestamp / 1000) * 1000;
+         outputRows.push([unixTime(currentTime), lineParts[3]]);
+       }
+     }
+     n++;
+   }
+
+   console.log("Done");
+
+   let output = "";
+
+   outputRows.forEach((row) => {
+     output += row.join(",") + "\n";
+   });
+
+   await fs.appendFile("test.csv", output);
+
+   lineReader.close();
+ }
+
+ // Format a millisecond Unix timestamp as "YYYY-MM-DD HH:MM:SS" (UTC)
+ function unixTime(unixtime) {
+   var u = new Date(unixtime);
+
+   return u.getUTCFullYear() +
+     '-' + ('0' + (u.getUTCMonth() + 1)).slice(-2) +
+     '-' + ('0' + u.getUTCDate()).slice(-2) +
+     ' ' + ('0' + u.getUTCHours()).slice(-2) +
+     ':' + ('0' + u.getUTCMinutes()).slice(-2) +
+     ':' + ('0' + u.getUTCSeconds()).slice(-2);
+ }
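
For orientation, the same one-quote-per-second downsampling can be sketched in pandas. This is an illustrative alternative, not part of the commit, and it assumes (as the script above does) that the raw daily files carry a millisecond timestamp in column 1 and the bid price in column 3:

import glob
import pandas as pd

frames = []
for day_file in sorted(glob.glob("./binance-btcusdt-futures-2020-2021/*.csv.gz")):
    raw = pd.read_csv(day_file)       # gzip is inferred from the extension
    df = raw.iloc[:, [1, 3]].copy()   # assumed layout: col 1 = ms timestamp, col 3 = bid
    df.columns = ["timestamp", "bid"]
    df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
    # Keep the first quote of each second, mirroring the streaming loop above
    frames.append(df.set_index("timestamp").resample("1s").first().dropna())

pd.concat(frames).to_csv("test.csv", date_format="%Y-%m-%d %H:%M:%S")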
data-processing/granite-ttm.py ADDED
@@ -0,0 +1,145 @@
+ import math
+ import os
+ import tempfile
+ import csv
+
+ import pandas as pd
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import OneCycleLR
+ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
+ from transformers.integrations import INTEGRATION_TO_CALLBACK
+
+ from tsfm_public import TimeSeriesPreprocessor, TrackingCallback, count_parameters, get_datasets
+ from tsfm_public.toolkit.get_model import get_model
+ from tsfm_public.toolkit.lr_finder import optimal_lr_finder
+
+ import warnings
+
+ # Suppress all warnings
+ warnings.filterwarnings("ignore")
+
+ # Set seed for reproducibility
+ SEED = 42
+ set_seed(SEED)
+
+ # TTM model path. The default is Granite-R2; other TTM releases can be selected below.
+ TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r2"
+ # TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r1"
+ # TTM_MODEL_PATH = "ibm-research/ttm-research-r2"
+
+ # Context length, i.e. the length of the history.
+ # Supported values: 512/1024/1536 for Granite-TTM-R2 and Research-Use-TTM-R2, 512/1024 for Granite-TTM-R1.
+ CONTEXT_LENGTH = 512
+
+ # Granite-TTM-R2 supports forecast lengths up to 720; Granite-TTM-R1 supports up to 96.
+ PREDICTION_LENGTH = 96
+
+ # Results dir
+ OUT_DIR = "ttm_finetuned_models/"
+
+ # Dataset
+ TARGET_DATASET = "binance-btcusdt-futures-2020-2021-1s"
+ dataset_path = "./test.csv"
+ timestamp_column = "timestamp"
+ id_columns = []  # columns that uniquely identify a time series
+
+ target_columns = ["bid"]
+ split_config = {
+     "train": 0.1,
+     "test": 0.9
+ }
+
+ data = pd.read_csv(
+     dataset_path,
+     parse_dates=[timestamp_column],
+     header=0
+ )
+
+ column_specifiers = {
+     "timestamp_column": timestamp_column,
+     "id_columns": id_columns,
+     "target_columns": target_columns,
+     "control_columns": [],
+ }
+
+
+ def zeroshot_eval(dataset_name, batch_size, context_length=512, forecast_length=96):
+     # Get data
+     tsp = TimeSeriesPreprocessor(
+         **column_specifiers,
+         context_length=context_length,
+         prediction_length=forecast_length,
+         scaling=True,
+         encode_categorical=False,
+         scaler_type="standard",
+     )
+
+     dset_train, dset_valid, dset_test = get_datasets(tsp, data, split_config)
+
+     # Load model
+     zeroshot_model = get_model(TTM_MODEL_PATH, context_length=context_length, prediction_length=forecast_length)
+
+     temp_dir = tempfile.mkdtemp()
+     # zeroshot_trainer
+     zeroshot_trainer = Trainer(
+         model=zeroshot_model,
+         args=TrainingArguments(
+             output_dir=temp_dir,
+             per_device_eval_batch_size=batch_size,
+             seed=SEED,
+             report_to="none",
+         ),
+     )
+
+     # train predictions
+     print("+" * 20, "Train MSE zero-shot", "+" * 20)
+     zeroshot_output = zeroshot_trainer.evaluate(dset_train)
+     print(zeroshot_output)
+
+     predictions_dict = zeroshot_trainer.predict(dset_train)
+     predictions_np_train = predictions_dict.predictions[0]
+
+     # test predictions
+     print("+" * 20, "Test MSE zero-shot", "+" * 20)
+     zeroshot_output = zeroshot_trainer.evaluate(dset_test)
+     print(zeroshot_output)
+
+     predictions_dict = zeroshot_trainer.predict(dset_test)
+     predictions_np_test = predictions_dict.predictions[0]
+
+     # One row per window: timestamp, last observed value in the context,
+     # prediction at the end of the horizon, actual value at the end of the horizon.
+     with open('results.csv', 'w', newline='') as csvfile:
+         writer = csv.writer(csvfile, delimiter=',')
+
+         for i in range(len(dset_train)):
+             writer.writerow([
+                 dset_train[i]['timestamp'],
+                 dset_train[i]['past_values'][CONTEXT_LENGTH-1][0].detach().item(),
+                 predictions_np_train[i][PREDICTION_LENGTH-1][0],
+                 dset_train[i]['future_values'][PREDICTION_LENGTH-1][0].detach().item()
+             ])
+
+         for i in range(len(dset_test)):
+             writer.writerow([
+                 dset_test[i]['timestamp'],
+                 dset_test[i]['past_values'][CONTEXT_LENGTH-1][0].detach().item(),
+                 predictions_np_test[i][PREDICTION_LENGTH-1][0],
+                 dset_test[i]['future_values'][PREDICTION_LENGTH-1][0].detach().item()
+             ])
+
+     # get backbone embeddings (if needed for further analysis)
+     # backbone_embedding = predictions_dict.predictions[1]
+     # print(backbone_embedding.shape)
+
+
+ zeroshot_eval(
+     dataset_name=TARGET_DATASET, context_length=CONTEXT_LENGTH, forecast_length=PREDICTION_LENGTH, batch_size=128
+ )
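
To sanity-check the output, results.csv can be loaded back with pandas. A minimal sketch, assuming the four unnamed columns are as written above (window timestamp, last observed bid, predicted bid at the 96-step horizon, actual bid at that horizon); the column names here are illustrative, and all values are in the preprocessor's standard-scaled units:

import pandas as pd

# The script writes no header; these column names are hypothetical labels.
results = pd.read_csv(
    "results.csv",
    header=None,
    names=["timestamp", "last_bid", "predicted_bid", "actual_bid"],
)

# Compare the horizon forecast against a naive "last observed value" baseline
mse_model = ((results["predicted_bid"] - results["actual_bid"]) ** 2).mean()
mse_naive = ((results["last_bid"] - results["actual_bid"]) ** 2).mean()
print(f"model MSE: {mse_model:.6f}   naive MSE: {mse_naive:.6f}")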