File size: 4,405 Bytes
1a68fc3 3fadc6e 05a2178 295de00 05a2178 15219b4 05a2178 15219b4 3fadc6e 05a2178 6a308c6 05a2178 48d8572 1a68fc3 05a2178 1a68fc3 05a2178 3fadc6e 6a308c6 3fadc6e 48d8572 1a68fc3 3fadc6e 15219b4 1a68fc3 3fadc6e 48d8572 6a308c6 1a68fc3 3fadc6e 295de00 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import textwrap
import unicodedata
import re
import zlib
from typing import Iterator, TextIO
import tqdm
import urllib3
def exact_div(x, y):
    """Divide ``x`` by ``y``, requiring the division to leave no remainder.

    Returns ``x // y``. An ``AssertionError`` is raised when ``x % y`` is
    non-zero (assertions are skipped when Python runs with ``-O``).
    """
    remainder = x % y
    assert remainder == 0
    return x // y
def str2bool(string):
    """Convert the exact strings ``"True"`` / ``"False"`` to a boolean.

    Raises:
        ValueError: if ``string`` is anything other than those two spellings
            (the match is case-sensitive).
    """
    str2val = {"True": True, "False": False}
    # Guard clause: reject unknown spellings up front.
    if string not in str2val:
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
    return str2val[string]
def optional_int(string):
    """Parse ``string`` as an int, mapping the literal ``"None"`` to ``None``."""
    if string == "None":
        return None
    return int(string)
def optional_float(string):
    """Parse ``string`` as a float, mapping the literal ``"None"`` to ``None``."""
    if string == "None":
        return None
    return float(string)
def compression_ratio(text) -> float:
    """Return how well ``text`` compresses under zlib.

    The ratio is ``uncompressed size / compressed size`` with both sizes
    measured in UTF-8 bytes, so highly repetitive text yields a large value.

    Args:
        text: the string to measure.

    Returns:
        The byte-length ratio as a float (``0.0`` for an empty string).
    """
    # Encode once and use the byte length for the numerator as well: comparing
    # a character count against a compressed *byte* count understates the
    # ratio for non-ASCII text, where one character spans several bytes.
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))
def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
    """Render a duration in seconds as ``[HH:]MM:SS<sep>mmm``.

    Args:
        seconds: non-negative duration; rounded to the nearest millisecond.
        always_include_hours: emit the ``HH:`` prefix even when it is zero.
        fractionalSeperator: string placed between seconds and milliseconds.

    Returns:
        The formatted timestamp string.
    """
    assert seconds >= 0, "non-negative timestamp expected"

    total_ms = round(seconds * 1000.0)
    # Peel off each unit in turn, carrying the remainder down.
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    whole_seconds, millis = divmod(remainder, 1_000)

    if always_include_hours or hours > 0:
        hours_marker = f"{hours:02d}:"
    else:
        hours_marker = ""
    return f"{hours_marker}{minutes:02d}:{whole_seconds:02d}{fractionalSeperator}{millis:03d}"
def write_txt(transcript: Iterator[dict], file: TextIO):
    """Write each segment's text to ``file``, one stripped line per segment.

    Args:
        transcript: iterable of segment dicts; only the ``'text'`` key is read.
        file: writable text stream; it is flushed after every line.
    """
    for segment in transcript:
        line = segment['text'].strip()
        print(line, file=file, flush=True)
def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
    """Write a transcript to ``file`` as WebVTT cues.

    Args:
        transcript: iterable of segment dicts; the ``'start'``, ``'end'`` and
            ``'text'`` keys are read.
        file: writable text stream; it is flushed after every cue.
        maxLineWidth: optional wrap width forwarded to ``process_text``;
            ``None`` leaves the text unwrapped.
    """
    # File header followed by the mandatory blank line.
    print("WEBVTT\n", file=file)
    for segment in transcript:
        # Strip surrounding whitespace so cues match what write_txt/write_srt
        # emit. "-->" is the cue timing separator, so it is rewritten to "->"
        # to keep it out of the cue text.
        text = process_text(segment['text'].strip(), maxLineWidth).replace('-->', '->')
        print(
            f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
            f"{text}\n",
            file=file,
            flush=True,
        )
def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
    """
    Emit a transcript as numbered SRT entries.

    Each segment becomes a 1-based index line, a ``start --> end`` timing line
    (hours always shown, comma before the milliseconds), the segment text and
    a blank separator line. The stream is flushed after every entry.

    Example usage:
        from pathlib import Path
        from whisper.utils import write_srt

        result = transcribe(model, audio_path, temperature=temperature, **args)

        # save SRT
        audio_basename = Path(audio_path).stem
        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
    """
    for index, segment in enumerate(transcript, start=1):
        # "-->" separates the two timestamps, so keep it out of the caption text.
        wrapped = process_text(segment['text'].strip(), maxLineWidth)
        text = wrapped.replace('-->', '->')

        begin = format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')
        finish = format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')

        entry = f"{index}\n{begin} --> {finish}\n{text}\n"
        print(entry, file=file, flush=True)
def process_text(text: str, maxLineWidth=None):
    """Wrap ``text`` so that no line exceeds ``maxLineWidth`` characters.

    Args:
        text: the text to wrap.
        maxLineWidth: maximum line width; ``None``, zero or a negative value
            disables wrapping.

    Returns:
        ``text`` unchanged when wrapping is disabled, otherwise the wrapped
        lines joined with ``"\\n"``.
    """
    # textwrap.wrap() raises ValueError for width <= 0, so zero is treated as
    # "no limit" just like negative values instead of crashing.
    if maxLineWidth is None or maxLineWidth <= 0:
        return text

    lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
    return '\n'.join(lines)
def slugify(value, allow_unicode=False):
    """
    Turn an arbitrary value into a slug (adapted from Django's
    ``django.utils.text.slugify``).

    The value is stringified and Unicode-normalised: NFKC when
    ``allow_unicode`` is true, otherwise NFKD followed by dropping every
    non-ASCII character. It is then lower-cased, characters other than word
    characters, whitespace and hyphens are removed, runs of whitespace and
    hyphens collapse to a single hyphen, and leading/trailing hyphens and
    underscores are trimmed.
    """
    text = str(value)
    if allow_unicode:
        text = unicodedata.normalize('NFKC', text)
    else:
        decomposed = unicodedata.normalize('NFKD', text)
        text = decomposed.encode('ascii', 'ignore').decode('ascii')

    cleaned = re.sub(r'[^\w\s-]', '', text.lower())
    collapsed = re.sub(r'[-\s]+', '-', cleaned)
    return collapsed.strip('-_')
def download_file(url: str, destination: str):
    """Download ``url`` to the local path ``destination`` with a progress bar.

    The response body is streamed in 8 KiB chunks, so the whole file is never
    held in memory.

    Args:
        url: address to fetch; opened with ``urllib.request.urlopen``.
        destination: path of the file to (over)write in binary mode.
    """
    # Function-scope import: the module only imports urllib3, which does not
    # provide urlopen(); the standard-library opener is what is needed here.
    import urllib.request

    with urllib.request.urlopen(url) as source, open(destination, "wb") as output:
        # Content-Length is optional (e.g. chunked responses); with total=None
        # tqdm still counts bytes but shows no percentage.
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length is not None else None

        # The file does ``import tqdm``, so the progress-bar class is
        # ``tqdm.tqdm``; calling the module itself raises TypeError.
        with tqdm.tqdm(
            total=total,
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))