"""Loading helpers for key ids, unit motion embeddings and the text encoder.

All embedding-related files are read from the ``EMBS`` directory; the text
encoder checkpoint is read from ``data/textencoder.pt``.
"""

import os

import numpy as np
import orjson
import torch

from model import TMR_textencoder

# Directory holding the per-split ``.keyids`` and ``_motion_embs_unit.npy`` files.
EMBS = "data/unit_motion_embs"


def load_json(path):
    """Parse the JSON file at *path* with ``orjson`` and return the result.

    The file is opened in binary mode because ``orjson.loads`` accepts raw
    bytes directly, so no separate text-decoding step is needed.
    """
    with open(path, "rb") as ff:
        return orjson.loads(ff.read())


def load_keyids(split: str) -> np.ndarray:
    """Return the key ids of *split* as a NumPy array of strings.

    Reads ``<EMBS>/<split>.keyids``; each line of the file becomes one entry,
    with surrounding whitespace (including the newline) stripped.
    """
    path = os.path.join(EMBS, f"{split}.keyids")
    with open(path) as ff:
        # Iterating the file object yields the same lines as ``readlines()``.
        keyids = np.array([line.strip() for line in ff])
    return keyids


def load_keyids_splits(splits) -> dict:
    """Return a ``{split: keyids}`` dict, calling ``load_keyids`` per split."""
    return {split: load_keyids(split) for split in splits}


def load_unit_motion_embs(split: str, device) -> torch.Tensor:
    """Load the motion embeddings of *split* as a tensor placed on *device*.

    Reads ``<EMBS>/<split>_motion_embs_unit.npy`` with ``np.load`` and wraps
    the array in a tensor via ``torch.from_numpy`` before moving it to
    *device*.
    NOTE(review): the ``unit`` in the file name presumably means the
    embeddings are unit-normalised -- confirm against the code that wrote
    these files.
    """
    path = os.path.join(EMBS, f"{split}_motion_embs_unit.npy")
    tensor = torch.from_numpy(np.load(path)).to(device)
    return tensor


def load_unit_motion_embs_splits(splits, device) -> dict:
    """Return a ``{split: tensor}`` dict via ``load_unit_motion_embs``."""
    return {split: load_unit_motion_embs(split, device) for split in splits}


def load_model(device):
    """Build the text encoder, load its weights, and return it on *device*.

    A ``TMR_textencoder`` is constructed with the fixed hyper-parameters
    below, the state dict stored in ``data/textencoder.pt`` is loaded into
    it, and the model is switched to evaluation mode before being moved to
    *device*.
    """
    text_params = {
        "latent_dim": 256,
        "ff_size": 1024,
        "num_layers": 6,
        "num_heads": 4,
        "activation": "gelu",
        "modelpath": "distilbert-base-uncased",
    }
    model = TMR_textencoder(**text_params)
    state_dict = torch.load("data/textencoder.pt", map_location=device)
    # load values for the transformer only
    # strict=False: the checkpoint keys need not match the module's keys
    # exactly, so parameters absent from the checkpoint keep the values they
    # got at construction time.
    # NOTE(review): presumably the checkpoint omits the pretrained language
    # model weights referenced by "modelpath" -- confirm.
    model.load_state_dict(state_dict, strict=False)
    model = model.eval()
    return model.to(device)