from lib import *
from twokenize import tokenizeRawTweetText
import re


def muscaps_tokenize(raw):
    """Lowercase *raw*, blank out ASCII punctuation, and split on whitespace.

    Args:
        raw: Caption text to tokenize.

    Returns:
        The list of lowercase tokens; empty when *raw* contains nothing but
        punctuation and whitespace.
    """
    # NOTE(review): `string` is not imported by name in this file; presumably
    # it arrives through `from lib import *` -- confirm.
    # A single translate pass replaces one str.replace call per punctuation
    # character and yields the same text.
    punct_to_space = str.maketrans(
        string.punctuation, ' ' * len(string.punctuation)
    )
    return raw.lower().translate(punct_to_space).split()


def get_device(device_id: int) -> torch.device:
    """Return the torch device for *device_id*, or the CPU when CUDA is absent.

    Args:
        device_id: Requested CUDA device index. An index beyond the last
            visible device is clamped to the last one. Negative values are
            not validated here.

    Returns:
        ``torch.device('cpu')`` if CUDA is unavailable, otherwise the
        (possibly clamped) ``cuda:<index>`` device.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    last_index = torch.cuda.device_count() - 1
    return torch.device(f'cuda:{min(last_index, device_id)}')


def preproc(caption, tokenizer, stop=True):
    """Encode *caption* with *tokenizer*, optionally appending a '.' stop mark.

    Every '.' is removed from the caption before encoding, so the only period
    that can appear in the result is the trailing stop marker.

    Args:
        caption: Caption text.
        tokenizer: Object exposing ``encode(text)``. The result is extended
            with ``+=``, so it is presumably a list of token ids -- confirm
            against the tokenizer in use.
        stop: When true, append ``tokenizer.encode('.')`` to the encoding.

    Returns:
        The encoded caption, with the stop marker appended if requested.
    """
    caption_proc = tokenizer.encode(caption.replace('.', ''))
    if stop:
        caption_proc += tokenizer.encode('.')
    return caption_proc


def postproc(caption):
    """Post-process a decoded caption and drop one trailing period.

    Args:
        caption: Decoded caption text.

    Returns:
        The caption after replacement, without a final '.' if it had one.
    """
    # NOTE(review): the search string is empty, so str.replace inserts '.'
    # before every character and at the end of the text. This looks like a
    # special-token literal that was lost -- confirm the intended token.
    caption = caption.replace('', '.')
    # endswith() is safe on an empty string, unlike indexing caption[-1].
    if caption.endswith('.'):
        caption = caption[:-1]
    return caption


class CheckpointManager:
    """Load and save torch checkpoints.

    Saving resolves names against ``checkpoint_dir``; loading uses the given
    path as is.
    """

    # Default location used when no directory is passed to the constructor.
    DEFAULT_CHECKPOINT_DIR = '/home/nsrivats/Repositories/MusicCaptioning/checkpoints'

    def __init__(self, checkpoint_dir=None):
        """Create a manager.

        Args:
            checkpoint_dir: Directory that ``save_checkpoint`` writes into.
                Defaults to ``DEFAULT_CHECKPOINT_DIR`` when omitted.
        """
        if checkpoint_dir is None:
            checkpoint_dir = self.DEFAULT_CHECKPOINT_DIR
        self.checkpoint_dir = checkpoint_dir

    def get_checkpoint(self, checkpoint, map_location=None):
        """Load and return the object stored in the file *checkpoint*.

        Args:
            checkpoint: Path of the checkpoint file. Unlike
                ``save_checkpoint``, it is opened exactly as given and is not
                joined with ``checkpoint_dir``.
            map_location: Forwarded to ``torch.load``; pass e.g. ``'cpu'`` to
                load a checkpoint saved from a GPU on a CPU-only host. The
                default ``None`` keeps torch's own placement.

        Returns:
            Whatever ``torch.load`` deserializes from the file.
        """
        # SECURITY: torch.load relies on pickle-based deserialization; only
        # load checkpoint files from trusted sources.
        with open(checkpoint, 'rb') as infile:
            return torch.load(infile, map_location=map_location)

    def save_checkpoint(self, state_dict, checkpoint):
        """Serialize *state_dict* to ``<checkpoint_dir>/<checkpoint>``.

        Args:
            state_dict: Object handed to ``torch.save``.
            checkpoint: File name (relative to ``checkpoint_dir``) to write.
        """
        filename = f'{self.checkpoint_dir}/{checkpoint}'
        with open(filename, 'wb') as outfile:
            torch.save(state_dict, outfile)

    def save_logs(self, logdir):
        """Placeholder; currently does nothing with *logdir*."""
        pass