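# Only run predictions: zero-shot inference with a pretrained Granite TTM model
# on a Binance BTCUSDT futures 1s dataset; last-step forecasts for the train
# and test splits are written to results.csv.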
import csv
import tempfile
import warnings

import pandas as pd
from transformers import Trainer, TrainingArguments, set_seed
from tsfm_public import TimeSeriesPreprocessor, get_datasets
from tsfm_public.toolkit.get_model import get_model

# Suppress all warnings
warnings.filterwarnings("ignore")
# Set seed for reproducibility
SEED = 42
set_seed(SEED)
# TTM model path. The default is Granite-R2; swap in one of the
# alternatives below to use another TTM release.
TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r2"
# TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r1"
# TTM_MODEL_PATH = "ibm-research/ttm-research-r2"
# Context length, i.e., the length of the history window.
# Supported values: 512/1024/1536 for Granite-TTM-R2 and Research-Use-TTM-R2; 512/1024 for Granite-TTM-R1.
CONTEXT_LENGTH = 512
# Forecast length: Granite-TTM-R2 supports up to 720, Granite-TTM-R1 up to 96.
PREDICTION_LENGTH = 96
# Results dir
OUT_DIR = "ttm_finetuned_models/"
# Dataset
TARGET_DATASET = "binance-btcusdt-futures-2020-2021-1s"
dataset_path = "./test.csv"
timestamp_column = "timestamp"
id_columns = []  # columns that uniquely identify each time series (empty here: a single series)
target_columns = ["bid"]
# Fractional split: the first 10% of rows form the train split and the last
# 90% the test split; any remainder between the two fractions (none here)
# would become the validation split.
split_config = {
    "train": 0.1,
    "test": 0.9,
}
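# For reference, the TTM example notebooks also pass explicit boundary indices
# (an illustrative sketch, assuming tsfm's get_datasets accepts this form; the
# numbers below are placeholders, not part of this script):
# split_config = {
#     "train": [0, 1_000_000],
#     "valid": [1_000_000, 1_100_000],
#     "test": [1_100_000, 2_000_000],
# }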
data = pd.read_csv(
    dataset_path,
    parse_dates=[timestamp_column],
    header=0,
)
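# Defensive step (a convenience sketch, not required if the CSV is already
# clean): the windowing assumes rows are in time order, so sort by timestamp
# and drop duplicate timestamps before preprocessing.
data = data.sort_values(timestamp_column).drop_duplicates(subset=timestamp_column).reset_index(drop=True)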
column_specifiers = {
    "timestamp_column": timestamp_column,
    "id_columns": id_columns,
    "target_columns": target_columns,
    "control_columns": [],
}
def zeroshot_eval(dataset_name, batch_size, context_length=512, forecast_length=96):
    # Preprocess the data: standard-scale the targets, then window into
    # train/valid/test datasets according to split_config
    tsp = TimeSeriesPreprocessor(
        **column_specifiers,
        context_length=context_length,
        prediction_length=forecast_length,
        scaling=True,
        encode_categorical=False,
        scaler_type="standard",
    )
    dset_train, dset_valid, dset_test = get_datasets(tsp, data, split_config)
    # Load the pretrained model at the requested context/forecast lengths
    zeroshot_model = get_model(TTM_MODEL_PATH, context_length=context_length, prediction_length=forecast_length)
    temp_dir = tempfile.mkdtemp()
    # Inference-only trainer; no optimizer settings are needed for predict()
    zeroshot_trainer = Trainer(
        model=zeroshot_model,
        args=TrainingArguments(
            output_dir=temp_dir,
            per_device_eval_batch_size=batch_size,
            seed=SEED,
            report_to="none",
        ),
    )
    # Zero-shot predictions on the train split
    print("+" * 20, "Train predict zero-shot", "+" * 20)
    predictions_output = zeroshot_trainer.predict(dset_train)
    predictions_np_train = predictions_output.predictions[0]
    # Zero-shot predictions on the test split
    print("+" * 20, "Test predict zero-shot", "+" * 20)
    predictions_output = zeroshot_trainer.predict(dset_test)
    predictions_np_test = predictions_output.predictions[0]
    # Write last-step results for both splits: the window timestamp, the last
    # observed context value, and the predicted vs. actual value at the final
    # forecast step (all in the preprocessor's scaled space).
    with open("results.csv", "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["timestamp", "last_context_value", "predicted", "actual"])
        for dset, preds in ((dset_train, predictions_np_train), (dset_test, predictions_np_test)):
            for i in range(len(dset)):
                writer.writerow([
                    dset[i]["timestamp"],
                    dset[i]["past_values"][context_length - 1][0].item(),
                    preds[i][forecast_length - 1][0],
                    dset[i]["future_values"][forecast_length - 1][0].item(),
                ])
    # Backbone embeddings are also returned, if needed for further analysis:
    # backbone_embedding = predictions_output.predictions[1]
    # print(backbone_embedding.shape)
zeroshot_eval(
    dataset_name=TARGET_DATASET, context_length=CONTEXT_LENGTH, forecast_length=PREDICTION_LENGTH, batch_size=128
)
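# A minimal post-hoc sketch: RMSE of the final-step forecasts written to
# results.csv by zeroshot_eval. Values are in the preprocessor's
# standard-scaled space, since no inverse scaling is applied before writing.
results = pd.read_csv("results.csv")
rmse = ((results["predicted"] - results["actual"]) ** 2).mean() ** 0.5
print(f"Final-step RMSE (scaled units): {rmse:.6f}")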