# -*- coding: utf-8 -*-
# taken from https://github.com/Felflare/rpunct/blob/master/rpunct/punctuate.py
# modified to support batching during gpu inference
__author__ = "Daulet N."
__email__ = "daulet.nurmanbetov@gmail.com"
import time
import logging
import webvtt
import torch
from io import StringIO
from nltk.tokenize import sent_tokenize
#from langdetect import detect
from simpletransformers.ner import NERModel
class RestorePuncts:
def __init__(self, wrds_per_pred=250):
self.wrds_per_pred = wrds_per_pred
self.overlap_wrds = 30
self.valid_labels = ['OU', 'OO', '.O', '!O', ',O', '.U', '!U', ',U', ':O', ';O', ':U', "'O", '-O', '?O', '?U']
self.model = NERModel("bert", "felflare/bert-restore-punctuation", labels=self.valid_labels,
args={"silent": True, "max_seq_length": 512})
        # use_cuda isn't working; manually setting the device below is a workaround
        # that loads the model onto the GPU correctly (note: "cuda:1" assumes a second GPU)
self.model.device = torch.device("cuda:1")
# dummy punctuate to load the model onto gpu
self.punctuate("hello how are you")
    def punctuate(self, text: str, batch_size: int = 32, lang: str = ''):
"""
        Performs punctuation restoration on arbitrarily large text.
        The original implementation rejected non-English input unless `lang='en'`
        was supplied; that language check is commented out below.
        Args:
            - text (str): Text to punctuate; can be a few words or arbitrarily long.
            - batch_size (int): Number of text slices per model prediction call.
            - lang (str): Explicit language of the input text.
"""
#if not lang and len(text) > 10:
# lang = detect(text)
#if lang != 'en':
# raise Exception(F"""Non English text detected. Restore Punctuation works only for English.
# If you are certain the input is English, pass argument lang='en' to this function.
# Punctuate received: {text}""")
        def chunks(L, n):
            # split list L into successive batches of size n
            return [L[x : x + n] for x in range(0, len(L), n)]
        # split up large text into BERT-digestible chunks
splits = self.split_on_toks(text, self.wrds_per_pred, self.overlap_wrds)
texts = [i["text"] for i in splits]
batches = chunks(texts, batch_size)
        preds_lst = []
        # run batched inference over the slices; keep the label predictions, discard the logits
        for batch in batches:
            batch_preds, _ = self.model.predict(batch)
            preds_lst.extend(batch_preds)
        # original (non-batched) per-slice prediction, kept for reference:
        # full_preds_lst = [self.predict(i['text']) for i in splits]  # tuples of (labels, logits)
        # preds_lst = [i[0][0] for i in full_preds_lst]
# join text slices
combined_preds = self.combine_results(text, preds_lst)
# create punctuated prediction
punct_text = self.punctuate_texts(combined_preds)
return punct_text
def predict(self, input_slice):
"""
Passes the unpunctuated text to the model for punctuation.
"""
predictions, raw_outputs = self.model.predict([input_slice])
return predictions, raw_outputs
@staticmethod
def split_on_toks(text, length, overlap):
"""
        Splits text into predefined slices of overlapping text with indexes (offsets)
        that tie back to the original text.
        This is done to bypass the 512-token limit of transformer models by
        sequentially feeding chunks of < 512 tokens.
Example output:
[{...}, {"text": "...", 'start_idx': 31354, 'end_idx': 32648}, {...}]
"""
wrds = text.replace('\n', ' ').split(" ")
resp = []
lst_chunk_idx = 0
i = 0
while True:
# words in the chunk and the overlapping portion
wrds_len = wrds[(length * i):(length * (i + 1))]
wrds_ovlp = wrds[(length * (i + 1)):((length * (i + 1)) + overlap)]
wrds_split = wrds_len + wrds_ovlp
# Break loop if no more words
if not wrds_split:
break
wrds_str = " ".join(wrds_split)
nxt_chunk_start_idx = len(" ".join(wrds_len))
lst_char_idx = len(" ".join(wrds_split))
resp_obj = {
"text": wrds_str,
"start_idx": lst_chunk_idx,
"end_idx": lst_char_idx + lst_chunk_idx,
}
resp.append(resp_obj)
lst_chunk_idx += nxt_chunk_start_idx + 1
i += 1
logging.info(f"Sliced transcript into {len(resp)} slices.")
return resp
@staticmethod
def combine_results(full_text: str, text_slices):
"""
        Given the full text and the predictions for each slice, combines the slice
        predictions back into a single sequence.
        Validates that the combined output matches the original text.
"""
split_full_text = full_text.replace('\n', ' ').split(" ")
split_full_text = [i for i in split_full_text if i]
split_full_text_len = len(split_full_text)
output_text = []
index = 0
if len(text_slices[-1]) <= 3 and len(text_slices) > 1:
text_slices = text_slices[:-1]
for _slice in text_slices:
slice_wrds = len(_slice)
for ix, wrd in enumerate(_slice):
# print(index, "|", str(list(wrd.keys())[0]), "|", split_full_text[index])
if index == split_full_text_len:
break
if split_full_text[index] == str(list(wrd.keys())[0]) and \
ix <= slice_wrds - 3 and text_slices[-1] != _slice:
index += 1
pred_item_tuple = list(wrd.items())[0]
output_text.append(pred_item_tuple)
elif split_full_text[index] == str(list(wrd.keys())[0]) and text_slices[-1] == _slice:
index += 1
pred_item_tuple = list(wrd.items())[0]
output_text.append(pred_item_tuple)
assert [i[0] for i in output_text] == split_full_text
return output_text
@staticmethod
def punctuate_texts(full_pred: list):
"""
        Given a list of predictions from the model, applies them to the text,
        thus punctuating it.
"""
punct_resp = ""
for i in full_pred:
word, label = i
if label[-1] == "U":
punct_wrd = word.capitalize()
else:
punct_wrd = word
if label[0] != "O":
punct_wrd += label[0]
punct_resp += punct_wrd + " "
punct_resp = punct_resp.strip()
        # Append a trailing period if one doesn't exist.
        if punct_resp and punct_resp[-1].isalnum():
punct_resp += "."
return punct_resp
if __name__ == "__main__":
start = time.time()
punct_model = RestorePuncts()
load_model = time.time()
print(f'Time to load model: {load_model - start}')
# read test file
# with open('en_lower.txt', 'r') as fp:
# # test_sample = fp.read()
# lines = fp.readlines()
    # source_text is only needed for the commented-out read_buffer variant below
    with open('sample.vtt', 'r') as fp:
        source_text = fp.read()
    # captions = webvtt.read_buffer(StringIO(source_text))
    captions = webvtt.read('sample.vtt')
source_sentences = [caption.text.replace('\r', '').replace('\n', ' ') for caption in captions]
# print(source_sentences)
sent = ' '.join(source_sentences)
punctuated = punct_model.punctuate(sent)
tokenised = sent_tokenize(punctuated)
# print(tokenised)
    # assign one tokenised sentence back to each caption (assumes sentence and caption counts match)
    for i in range(len(tokenised)):
        captions[i].text = tokenised[i]
# return captions.content
captions.save('my_captions.vtt')
end = time.time()
print(f'Time for run: {end - load_model}')
print(f'Total time: {end - start}')