import json

import torch


class CharTokenizer:
    """Character-level tokenizer for prompt/text pairs.

    Special tokens (the exact strings here are placeholders; any distinct markers work,
    as long as encoding and decoding agree on them):
    "<pad>" -- id 0, also used as the fallback for unknown characters
    "<bos>" -- separates the prompt from the text
    "<eos>" -- marks the end of a sequence
    """

    def __init__(self, corpus=None, vocab=None):
        if vocab is not None:
            self.vocab = vocab
        elif corpus is not None:
            self.vocab = self._build_vocab(corpus)
        else:
            raise ValueError("Either corpus or vocab has to be supplied")
        # Inverse mapping: list of characters ordered by token id.
        self.id2vocab = [char for char, index in sorted(self.vocab.items(), key=lambda item: item[1])]

    def _tokenize(self, text):
        # A "token" is simply a single character.
        return list(text)

    def __call__(self, prompt, text=None, add_eos_token=False):
        # Unknown characters fall back to id 0 (the "<pad>" token).
        token_ids = [self.vocab.get(token, 0) for token in self._tokenize(prompt)]
        if text is not None:
            text_token_ids = [self.vocab.get(token, 0) for token in self._tokenize(text)]
            token_ids = token_ids + [self.vocab["<bos>"]] + text_token_ids
        if add_eos_token:
            token_ids = token_ids + [self.vocab["<eos>"]]
        # Add a batch dimension: shape (1, sequence_length).
        input_ids_tensor = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
        attention_masks = torch.ones_like(input_ids_tensor)
        return {"input_ids": input_ids_tensor, "attention_mask": attention_masks}

    def _build_vocab(self, corpus):
        vocab = {"<pad>": 0}
        # Reserve ids for the verse-length digits 3 through 9 used in prompts.
        for verse_length in range(3, 10):
            vocab[str(verse_length)] = len(vocab)
        for doc in corpus:
            chars = self._tokenize(doc)
            for char in chars:
                if char not in vocab:
                    vocab[char] = len(vocab)
        vocab["<bos>"] = len(vocab)
        vocab["<eos>"] = len(vocab)
        return vocab

    def decode(self, token_ids):
        chars = [self.id2vocab[token_id] for token_id in token_ids.flatten().tolist()]
        # Drop special tokens so only the readable text remains.
        filtered_chars = [char for char in chars if char not in ["<pad>", "<bos>", "<eos>"]]
        return "".join(filtered_chars)

    def save(self, filepath):
        with open(filepath, "w") as f:
            json.dump(self.vocab, f)

    @classmethod
    def load(cls, filepath):
        with open(filepath) as f:
            vocab = json.load(f)
        return cls(vocab=vocab)
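

# --- Minimal usage sketch (illustrative only) ---
# The corpus, prompt, and file path below are made-up examples; in practice the
# tokenizer would be built from the real training documents.
if __name__ == "__main__":
    sample_corpus = ["an old silent pond\na frog jumps into the pond\nsplash silence again"]
    tokenizer = CharTokenizer(corpus=sample_corpus)

    # Encode a prompt (a verse-length digit) together with a text, appending the EOS token.
    encoded = tokenizer("5", text="a frog jumps", add_eos_token=True)
    print(encoded["input_ids"].shape)              # torch.Size([1, 15])
    print(tokenizer.decode(encoded["input_ids"]))  # "5a frog jumps"

    # Round-trip the vocabulary through disk.
    tokenizer.save("char_tokenizer.json")
    restored = CharTokenizer.load("char_tokenizer.json")
    assert restored.vocab == tokenizer.vocab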