Initial commit
Browse files- .gitignore +1 -0
- data-processing/.gitignore +3 -0
- data-processing/convert.js +57 -0
- data-processing/granite-ttm.py +145 -0
.gitignore
CHANGED
@@ -5,6 +5,7 @@ __pycache__/
|
|
5 |
.ipynb_checkpoints
|
6 |
*ipynb
|
7 |
.vscode/
|
|
|
8 |
|
9 |
eval-queue/
|
10 |
eval-results/
|
|
|
5 |
.ipynb_checkpoints
|
6 |
*ipynb
|
7 |
.vscode/
|
8 |
+
.idea
|
9 |
|
10 |
eval-queue/
|
11 |
eval-results/
|
data-processing/.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
binance-btcusdt-futures-2020-2021
|
2 |
+
venv
|
3 |
+
test.*
|
data-processing/convert.js
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import fs from 'node:fs/promises';
|
2 |
+
import * as zlib from "zlib";
|
3 |
+
import readline from 'node:readline';
|
4 |
+
|
5 |
+
let currentTime = 1593561599000;
|
6 |
+
|
7 |
+
await fs.appendFile("test.csv", "timestamp,bid\n");
|
8 |
+
|
9 |
+
for await (const dayFile of fs.glob("./binance-btcusdt-futures-2020-2021/*.csv.gz")) {
|
10 |
+
console.log("Reading", dayFile, "...");
|
11 |
+
|
12 |
+
let outputRows = [];
|
13 |
+
|
14 |
+
const fd = await fs.open(dayFile);
|
15 |
+
|
16 |
+
let lineReader = readline.createInterface({
|
17 |
+
input: fd.createReadStream().pipe(zlib.createGunzip())
|
18 |
+
});
|
19 |
+
|
20 |
+
let n = 0;
|
21 |
+
for await (const line of lineReader) {
|
22 |
+
if (n > 0) {
|
23 |
+
let lineParts = line.split(',');
|
24 |
+
let timestamp = parseInt(lineParts[1]);
|
25 |
+
if (timestamp >= currentTime + 1000) {
|
26 |
+
currentTime = Math.floor(timestamp / 1000) * 1000;
|
27 |
+
outputRows.push([unixTime(currentTime), lineParts[3]]);
|
28 |
+
}
|
29 |
+
}
|
30 |
+
n++;
|
31 |
+
}
|
32 |
+
|
33 |
+
console.log("Done");
|
34 |
+
|
35 |
+
let output = "";
|
36 |
+
|
37 |
+
outputRows.forEach((row) => {
|
38 |
+
output += row.join(",") + "\n";
|
39 |
+
});
|
40 |
+
|
41 |
+
await fs.appendFile("test.csv", output);
|
42 |
+
|
43 |
+
lineReader.close();
|
44 |
+
}
|
45 |
+
|
46 |
+
|
47 |
+
function unixTime(unixtime) {
|
48 |
+
|
49 |
+
var u = new Date(unixtime);
|
50 |
+
|
51 |
+
return u.getUTCFullYear() +
|
52 |
+
'-' + ('0' + (u.getUTCMonth() + 1)).slice(-2) +
|
53 |
+
'-' + ('0' + u.getUTCDate()).slice(-2) +
|
54 |
+
' ' + ('0' + u.getUTCHours()).slice(-2) +
|
55 |
+
':' + ('0' + u.getUTCMinutes()).slice(-2) +
|
56 |
+
':' + ('0' + u.getUTCSeconds()).slice(-2);
|
57 |
+
}
|
data-processing/granite-ttm.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import tempfile
|
4 |
+
import csv
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
from torch.optim import AdamW
|
8 |
+
from torch.optim.lr_scheduler import OneCycleLR
|
9 |
+
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
|
10 |
+
from transformers.integrations import INTEGRATION_TO_CALLBACK
|
11 |
+
|
12 |
+
from tsfm_public import TimeSeriesPreprocessor, TrackingCallback, count_parameters, get_datasets
|
13 |
+
from tsfm_public.toolkit.get_model import get_model
|
14 |
+
from tsfm_public.toolkit.lr_finder import optimal_lr_finder
|
15 |
+
|
16 |
+
import warnings
|
17 |
+
|
18 |
+
# Suppress all warnings
|
19 |
+
warnings.filterwarnings("ignore")
|
20 |
+
|
21 |
+
# Set seed for reproducibility
|
22 |
+
SEED = 42
|
23 |
+
set_seed(SEED)
|
24 |
+
|
25 |
+
# TTM Model path. The default model path is Granite-R2. Below, you can choose other TTM releases.
|
26 |
+
TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r2"
|
27 |
+
# TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r1"
|
28 |
+
# TTM_MODEL_PATH = "ibm-research/ttm-research-r2"
|
29 |
+
|
30 |
+
# Context length, Or Length of the history.
|
31 |
+
# Currently supported values are: 512/1024/1536 for Granite-TTM-R2 and Research-Use-TTM-R2, and 512/1024 for Granite-TTM-R1
|
32 |
+
CONTEXT_LENGTH = 512
|
33 |
+
|
34 |
+
# Granite-TTM-R2 supports forecast length upto 720 and Granite-TTM-R1 supports forecast length upto 96
|
35 |
+
PREDICTION_LENGTH = 96
|
36 |
+
|
37 |
+
# Results dir
|
38 |
+
OUT_DIR = "ttm_finetuned_models/"
|
39 |
+
|
40 |
+
# Dataset
|
41 |
+
TARGET_DATASET = "binance-btcusdt-futures-2020-2021-1s"
|
42 |
+
dataset_path = "./test.csv"
|
43 |
+
timestamp_column = "timestamp"
|
44 |
+
id_columns = [] # mention the ids that uniquely identify a time-series.
|
45 |
+
|
46 |
+
target_columns = ["bid"]
|
47 |
+
split_config = {
|
48 |
+
"train": 0.1,
|
49 |
+
"test": 0.9
|
50 |
+
}
|
51 |
+
# Understanding the split config -- slides
|
52 |
+
|
53 |
+
data = pd.read_csv(
|
54 |
+
dataset_path,
|
55 |
+
parse_dates=[timestamp_column],
|
56 |
+
header=0
|
57 |
+
)
|
58 |
+
|
59 |
+
column_specifiers = {
|
60 |
+
"timestamp_column": timestamp_column,
|
61 |
+
"id_columns": id_columns,
|
62 |
+
"target_columns": target_columns,
|
63 |
+
"control_columns": [],
|
64 |
+
}
|
65 |
+
|
66 |
+
def zeroshot_eval(dataset_name, batch_size, context_length=512, forecast_length=96):
|
67 |
+
# Get data
|
68 |
+
|
69 |
+
tsp = TimeSeriesPreprocessor(
|
70 |
+
**column_specifiers,
|
71 |
+
context_length=context_length,
|
72 |
+
prediction_length=forecast_length,
|
73 |
+
scaling=True,
|
74 |
+
encode_categorical=False,
|
75 |
+
scaler_type="standard",
|
76 |
+
)
|
77 |
+
|
78 |
+
dset_train, dset_valid, dset_test = get_datasets(tsp, data, split_config)
|
79 |
+
|
80 |
+
# Load model
|
81 |
+
zeroshot_model = get_model(TTM_MODEL_PATH, context_length=context_length, prediction_length=forecast_length)
|
82 |
+
|
83 |
+
temp_dir = tempfile.mkdtemp()
|
84 |
+
# zeroshot_trainer
|
85 |
+
zeroshot_trainer = Trainer(
|
86 |
+
model=zeroshot_model,
|
87 |
+
args=TrainingArguments(
|
88 |
+
output_dir=temp_dir,
|
89 |
+
per_device_eval_batch_size=batch_size,
|
90 |
+
seed=SEED,
|
91 |
+
report_to="none",
|
92 |
+
),
|
93 |
+
)
|
94 |
+
|
95 |
+
# train predictions
|
96 |
+
|
97 |
+
print("+" * 20, "Train MSE zero-shot", "+" * 20)
|
98 |
+
zeroshot_output = zeroshot_trainer.evaluate(dset_train)
|
99 |
+
print(zeroshot_output)
|
100 |
+
|
101 |
+
predictions_dict = zeroshot_trainer.predict(dset_train)
|
102 |
+
|
103 |
+
predictions_np_train = predictions_dict.predictions[0]
|
104 |
+
|
105 |
+
# test predictions
|
106 |
+
|
107 |
+
print("+" * 20, "Test MSE zero-shot", "+" * 20)
|
108 |
+
zeroshot_output = zeroshot_trainer.evaluate(dset_test)
|
109 |
+
print(zeroshot_output)
|
110 |
+
|
111 |
+
predictions_dict = zeroshot_trainer.predict(dset_test)
|
112 |
+
|
113 |
+
predictions_np_test = predictions_dict.predictions[0]
|
114 |
+
|
115 |
+
with open('results.csv', 'w', newline='') as csvfile:
|
116 |
+
|
117 |
+
writer = csv.writer(csvfile, delimiter=',')
|
118 |
+
|
119 |
+
for i in range(len(dset_train)):
|
120 |
+
writer.writerow([
|
121 |
+
dset_train[i]['timestamp'],
|
122 |
+
dset_train[i]['past_values'][CONTEXT_LENGTH-1][0].detach().item(),
|
123 |
+
predictions_np_train[i][PREDICTION_LENGTH-1][0],
|
124 |
+
dset_train[i]['future_values'][PREDICTION_LENGTH-1][0].detach().item()
|
125 |
+
])
|
126 |
+
|
127 |
+
for i in range(len(dset_test)):
|
128 |
+
writer.writerow([
|
129 |
+
dset_test[i]['timestamp'],
|
130 |
+
dset_test[i]['past_values'][CONTEXT_LENGTH-1][0].detach().item(),
|
131 |
+
predictions_np_test[i][PREDICTION_LENGTH-1][0],
|
132 |
+
dset_test[i]['future_values'][PREDICTION_LENGTH-1][0].detach().item()
|
133 |
+
])
|
134 |
+
|
135 |
+
|
136 |
+
# get backbone embeddings (if needed for further analysis)
|
137 |
+
|
138 |
+
#backbone_embedding = predictions_dict.predictions[1]
|
139 |
+
|
140 |
+
#print(backbone_embedding.shape)
|
141 |
+
|
142 |
+
|
143 |
+
zeroshot_eval(
|
144 |
+
dataset_name=TARGET_DATASET, context_length=CONTEXT_LENGTH, forecast_length=PREDICTION_LENGTH, batch_size=128
|
145 |
+
)
|