|
import math |
|
import os |
|
import tempfile |
|
import csv |
|
|
|
import pandas as pd |
|
from torch.optim import AdamW |
|
from torch.optim.lr_scheduler import OneCycleLR |
|
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed |
|
from transformers.integrations import INTEGRATION_TO_CALLBACK |
|
|
|
from tsfm_public import TimeSeriesPreprocessor, TrackingCallback, count_parameters, get_datasets |
|
from tsfm_public.toolkit.get_model import get_model |
|
from tsfm_public.toolkit.lr_finder import optimal_lr_finder |
|
|
|
import warnings |
|
|
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
# Fix RNG seeds (python / numpy / torch) so runs are reproducible.
SEED = 42
set_seed(SEED)

# Pretrained TTM checkpoint identifier on the HuggingFace hub.
TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r2"

# Number of past time steps fed to the model per sample window.
CONTEXT_LENGTH = 512

# Number of future time steps the model forecasts per sample window.
PREDICTION_LENGTH = 96

# Output directory for fine-tuned checkpoints.
# NOTE(review): unused by the zero-shot path below — kept for later fine-tuning runs?
OUT_DIR = "ttm_finetuned_models/"

# Dataset identity (label only) and on-disk location of the raw series.
TARGET_DATASET = "binance-btcusdt-futures-2020-2021-1s"
dataset_path = "./test.csv"
timestamp_column = "timestamp"
id_columns = []  # no per-series id columns: the file holds a single series

# Column(s) to forecast.
target_columns = ["bid"]
# Fractional train/test split handed to get_datasets.
# NOTE(review): no explicit "valid" fraction — presumably get_datasets derives
# the validation slice from the remainder/defaults; confirm against tsfm_public.
split_config = {
    "train": 0.1,
    "test": 0.9
}

# Load the raw series; the timestamp column is parsed to datetime up front
# so the preprocessor receives proper datetimes rather than strings.
data = pd.read_csv(
    dataset_path,
    parse_dates=[timestamp_column],
    header=0
)

# Column-role mapping unpacked into TimeSeriesPreprocessor(**column_specifiers).
column_specifiers = {
    "timestamp_column": timestamp_column,
    "id_columns": id_columns,
    "target_columns": target_columns,
    "control_columns": [],
}
|
|
|
def zeroshot_eval(dataset_name, batch_size, context_length=512, forecast_length=96):
    """Run zero-shot TTM forecasting on the module-level `data` and dump results to CSV.

    Builds a TimeSeriesPreprocessor, splits `data` via the module-level
    `split_config`, loads the pretrained TTM checkpoint, predicts on the
    train and test splits, and writes one row per sample window to
    ``results.csv``::

        [timestamp, last observed value, predicted value at horizon end,
         ground-truth value at horizon end]

    Args:
        dataset_name: Label for the dataset. Currently unused; kept for
            interface compatibility with callers/logging.
        batch_size: Per-device eval batch size for Trainer.predict.
        context_length: Number of past steps per window (must match the model).
        forecast_length: Number of forecast steps per window.

    Side effects:
        Creates a temp dir for Trainer output; overwrites ``results.csv``
        in the working directory.
    """
    tsp = TimeSeriesPreprocessor(
        **column_specifiers,
        context_length=context_length,
        prediction_length=forecast_length,
        scaling=True,
        encode_categorical=False,
        scaler_type="standard",
    )

    # Validation split is produced but not evaluated in the zero-shot path.
    dset_train, _dset_valid, dset_test = get_datasets(tsp, data, split_config)

    zeroshot_model = get_model(TTM_MODEL_PATH, context_length=context_length, prediction_length=forecast_length)

    # Trainer requires an output_dir even when we only run predict().
    temp_dir = tempfile.mkdtemp()

    zeroshot_trainer = Trainer(
        model=zeroshot_model,
        args=TrainingArguments(
            output_dir=temp_dir,
            per_device_eval_batch_size=batch_size,
            seed=SEED,
            report_to="none",
        ),
    )

    def _write_rows(writer, dset, predictions):
        # One CSV row per window: timestamp, last observed value, model's
        # prediction at the end of the horizon, and the true value there.
        # BUG FIX: previously indexed with the module globals CONTEXT_LENGTH /
        # PREDICTION_LENGTH instead of this call's context_length /
        # forecast_length, which breaks for non-default lengths.
        for i in range(len(dset)):
            sample = dset[i]  # bind once: dataset __getitem__ may recompute per call
            writer.writerow([
                sample['timestamp'],
                sample['past_values'][context_length - 1][0].detach().item(),
                predictions[i][forecast_length - 1][0],
                sample['future_values'][forecast_length - 1][0].detach().item(),
            ])

    print("+" * 20, "Train predict zero-shot", "+" * 20)
    predictions_np_train = zeroshot_trainer.predict(dset_train).predictions[0]

    print("+" * 20, "Test predict zero-shot", "+" * 20)
    predictions_np_test = zeroshot_trainer.predict(dset_test).predictions[0]

    with open('results.csv', 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter=',')
        _write_rows(writer, dset_train, predictions_np_train)
        _write_rows(writer, dset_test, predictions_np_test)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: evaluate the pretrained TTM zero-shot on the configured dataset.
zeroshot_eval(
    dataset_name=TARGET_DATASET,
    batch_size=128,
    context_length=CONTEXT_LENGTH,
    forecast_length=PREDICTION_LENGTH,
)
|
|