# whispertube_backend / transcription.py
# Source: Hugging Face file view -- author uzi007, commit 052b4e7
# ("Added GPU Models"), 7.87 kB.
import os
from abc import ABC, abstractmethod
from urllib.parse import parse_qs, urlparse

from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api.formatters import SRTFormatter, WebVTTFormatter
class Transcription(ABC):
    '''
    Abstract base for every transcription back-end in this module.

    Stores the media path, output directory and subtitle format shared by
    all concrete transcribers. Subclasses must implement
    ``generate_transcript`` and ``save_transcript``.
    '''

    def __init__(self, media_path, output_path, subtitle_format):
        '''
        :param media_path: Path of the media file to transcribe.
        :param output_path: Output directory, resolved against the current
            working directory (an absolute path is kept unchanged, as per
            ``os.path.join``).
        :param subtitle_format: Subtitle file extension such as ``'srt'``.
            String values are lower-cased (see below).
        '''
        self.media_path = media_path
        self.output_path = os.path.join(os.getcwd(), output_path)
        # Media path with its extension stripped; the directory part is kept.
        self.filename = os.path.splitext(media_path)[0]
        # Normalise case once here: subclasses validate the format with
        # ``.lower()`` but later compare it against lower-case literals, so a
        # value such as 'SRT' used to pass validation and then match nothing.
        # Non-string values are stored unchanged.
        if isinstance(subtitle_format, str):
            subtitle_format = subtitle_format.lower()
        self.subtitle_format = subtitle_format

    @abstractmethod
    def generate_transcript(self):
        '''Produce the transcript for ``self.media_path``.'''
        pass

    @abstractmethod
    def save_transcript(self):
        '''Persist the transcript produced by ``generate_transcript``.'''
        pass
class YouTubeTranscriptAPI(Transcription):
    '''
    Transcriber that fetches existing YouTube captions through
    ``youtube_transcript_api`` instead of running a speech model.
    '''

    def __init__(self, url, media_path, output_path, subtitle_format='srt', transcript_language='en'):
        '''
        :param url: YouTube video URL (``.../watch?v=ID[&...]`` or
            ``https://youtu.be/ID``).
        :param media_path: Path of the media file; the subtitle file is
            written next to it with the extension replaced.
        :param output_path: Output directory (stored by the base class; not
            used when this class saves).
        :param subtitle_format: ``'srt'`` or ``'vtt'`` (case-insensitive).
        :param transcript_language: Language code requested from YouTube.
        '''
        super().__init__(media_path, output_path, subtitle_format)
        self.url = url
        self.video_id = self._extract_video_id(url)
        self.transcript_language = transcript_language
        self.supported_subtitle_formats = ['srt', 'vtt']
        assert(self.subtitle_format.lower() in self.supported_subtitle_formats)

    @staticmethod
    def _extract_video_id(url):
        '''
        Return the video id contained in ``url``.

        Handles ``youtu.be/ID`` short links and the ``v`` query parameter
        (ignoring any other parameters such as ``&t=10s``). Falls back to the
        previous ``split('v=')`` behaviour for anything else, which raises
        IndexError when no ``v=`` is present.
        '''
        parsed = urlparse(url)
        host = (parsed.hostname or '').lower()
        if host == 'youtu.be' or host.endswith('.youtu.be'):
            short_id = parsed.path.lstrip('/').split('/')[0]
            if short_id:
                return short_id
        ids = parse_qs(parsed.query).get('v')
        if ids and ids[0]:
            return ids[0]
        # Not a URL urlparse understands: keep the legacy split but drop any
        # trailing '&param=...' that would otherwise corrupt the id.
        return url.split('v=')[1].split('&')[0]

    def get_available_transcripts(self):
        '''
        Returns a dictionary of available transcripts & their info.

        Keys are language codes; values hold ``language``, ``is_generated``
        and ``is_translatable`` for that transcript.
        '''
        transcript_list = YouTubeTranscriptApi.list_transcripts(self.video_id)
        return {
            transcript.language_code: {
                'language': transcript.language,
                'is_generated': transcript.is_generated,
                'is_translatable': transcript.is_translatable,
            }
            for transcript in transcript_list
        }

    def generate_transcript(self):
        '''
        Fetches the transcript in ``self.transcript_language`` and stores it
        on ``self.transcript``.
        '''
        self.transcript = YouTubeTranscriptApi.get_transcript(self.video_id, languages=[self.transcript_language])

    def save_transcript(self):
        '''
        Writes the transcript (from ``generate_transcript``) into a subtitle
        file next to the media file and returns its path.

        :raises ValueError: if ``subtitle_format`` is neither srt nor vtt.
        '''
        formatters = {'srt': SRTFormatter, 'vtt': WebVTTFormatter}
        # Lower-case here as well so dispatch agrees with the
        # case-insensitive check in __init__.
        fmt = self.subtitle_format.lower()
        try:
            formatter = formatters[fmt]()
        except KeyError:
            # Previously this path crashed with UnboundLocalError.
            raise ValueError(f'Unsupported subtitle format: {self.subtitle_format!r}') from None
        formatted_transcript = formatter.format_transcript(self.transcript)
        file_path = f'{self.filename}.{fmt}'
        with open(file_path, 'w', encoding='utf-8') as transcript_file:
            transcript_file.write(formatted_transcript)
        return file_path
class Whisper(Transcription):
    '''
    Shared base for the local Whisper-model transcribers.

    Adds the ``word_level`` switch (per-word vs. per-segment output) and
    restricts the subtitle format to 'ass', 'srt' or 'vtt'
    (case-insensitive check).
    '''

    def __init__(self, media_path, output_path, subtitle_format, word_level):
        super().__init__(media_path, output_path, subtitle_format)
        self.supported_subtitle_formats = ['ass', 'srt', 'vtt']
        self.word_level = word_level
        requested_format = self.subtitle_format.lower()
        assert requested_format in self.supported_subtitle_formats
class FasterWhisper(Whisper):
    '''
    Whisper transcriber driven by a faster-whisper style model.
    '''

    def __init__(self, model, media_path, output_path, subtitle_format='srt', word_level=True):
        '''
        :param model: Loaded model whose ``transcribe(path, **kwargs)`` returns
            ``(segments, info)`` -- presumably a faster-whisper ``WhisperModel``.
        :param media_path: Media file to transcribe.
        :param output_path: Directory used by ``save_transcript`` by default.
        :param subtitle_format: One of 'ass', 'srt', 'vtt'.
        :param word_level: ``True`` for one entry per word, ``False`` for one
            entry per segment.
        '''
        super().__init__(media_path, output_path, subtitle_format, word_level)
        self.model = model

    @staticmethod
    def _make_entry(text, start, end):
        '''Build one ``{'text', 'start', 'end'}`` entry, times rounded to 2 decimals.'''
        return {'text': text, 'start': round(start, 2), 'end': round(end, 2)}

    def generate_transcript(self):
        '''
        Transcribes ``self.media_path`` and stores the result on the instance.

        Sets ``self.text``, ``self.language`` and ``self.segments`` and returns
        them as a dict with keys 'language', 'text' and 'segments'. Each
        segment entry has 'text', 'start' and 'end'.
        '''
        if self.word_level:
            # Word-level: request word timestamps, flatten to one entry per word.
            segments, info = self.model.transcribe(self.media_path, word_timestamps=True)
            all_segments = [
                self._make_entry(word.word, word.start, word.end)
                for segment in segments
                # defensive: treat a missing word list as empty
                for word in (segment.words or [])
            ]
        else:
            # Segment-level: one entry per decoded segment.
            segments, info = self.model.transcribe(self.media_path, beam_size=5)
            all_segments = [
                self._make_entry(segment.text, segment.start, segment.end)
                for segment in segments
            ]
        # Whisper word/segment texts already carry their own leading space
        # (e.g. ' Hello'), so concatenate directly. Joining with ' ' produced
        # doubled spaces, and spurious spaces for unspaced scripts.
        self.text = ''.join(entry['text'] for entry in all_segments).strip()
        self.language = info.language
        self.segments = all_segments
        return {
            'language': self.language,
            'text': self.text,
            'segments': self.segments,
        }

    def save_transcript(self, transcript=None, output_file=None):
        '''
        Writes the transcript text to a UTF-8 text file and returns its path.

        faster-whisper has no built-in writer, so only plain text is written
        here (no subtitle file).

        :param transcript: ``None`` to use ``self.text`` (set by
            ``generate_transcript``), a dict as returned by
            ``generate_transcript`` (its 'text' value is written), or a string.
        :param output_file: Destination path; defaults to
            ``<output_path>/<filename>.txt``, matching StableWhisper.
        :returns: The path written to.
        '''
        if transcript is None:
            text = self.text
        elif isinstance(transcript, dict):
            text = transcript['text']
        else:
            text = str(transcript)
        if output_file is None:
            output_file = os.path.join(self.output_path, f'{self.filename}.txt')
        with open(output_file, 'w', encoding='utf-8') as file:
            file.write(text)
        return output_file
class StableWhisper(Whisper):
    '''
    Whisper transcriber driven by a stable-ts style model, whose result
    object can also write subtitle files.
    '''

    def __init__(self, model, media_path, output_path, subtitle_format='srt', word_level=True):
        '''
        :param model: Loaded model whose ``transcribe`` returns a result with
            ``to_dict``, ``to_ass`` and ``to_srt_vtt`` -- presumably stable-ts.
        :param media_path: Media file to transcribe.
        :param output_path: Directory the transcript and subtitles are written to.
        :param subtitle_format: One of 'ass', 'srt', 'vtt'.
        :param word_level: ``True`` for word-level output, ``False`` for
            segment-level output.
        '''
        super().__init__(media_path, output_path, subtitle_format, word_level)
        self.model = model

    def generate_transcript(self):
        '''
        Transcribes ``self.media_path`` and stores the result on the instance.

        Keeps the raw result on ``self.result`` (needed by ``save_subtitles``)
        and its dict form on ``self.resultdict``. Sets ``self.text``,
        ``self.language`` and ``self.segments``, and returns them as a dict
        with keys 'language', 'text' and 'segments'.
        '''
        self.result = self.model.transcribe(self.media_path, word_timestamps=self.word_level)
        self.resultdict = self.result.to_dict()
        raw_segments = self.resultdict['segments']
        if self.word_level:
            # One entry per word across all segments.
            all_segments = [
                {
                    'text': word['word'],
                    'start': round(word['start'], 2),
                    'end': round(word['end'], 2),
                }
                for segment in raw_segments
                for word in segment['words']
            ]
        else:
            # One entry per segment.
            all_segments = [
                {
                    'text': segment['text'],
                    'start': round(segment['start'], 2),
                    'end': round(segment['end'], 2),
                }
                for segment in raw_segments
            ]
        self.text = self.resultdict['text']
        self.language = self.resultdict['language']
        self.segments = all_segments
        return {
            'language': self.language,
            'text': self.text,
            'segments': self.segments,
        }

    def save_transcript(self):
        '''
        Writes ``self.text`` to ``<output_path>/<filename>.txt`` (UTF-8) and
        returns the path.
        '''
        # NOTE(review): self.filename keeps the media file's directory. If it
        # is absolute, os.path.join discards output_path -- confirm intended.
        file_path = os.path.join(self.output_path, f'{self.filename}.txt')
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(self.text)
        return file_path

    def save_subtitles(self):
        '''
        Writes the subtitles in ``self.subtitle_format`` and returns the path.

        :raises ValueError: if the format is not 'ass', 'srt' or 'vtt'.
        '''
        # Case-insensitive, to agree with the validation in Whisper.__init__.
        fmt = self.subtitle_format.lower()
        file_path = os.path.join(self.output_path, f'{self.filename}.{fmt}')
        if fmt == 'ass':
            self.result.to_ass(file_path, segment_level=True, word_level=self.word_level)
        elif fmt in ('srt', 'vtt'):
            self.result.to_srt_vtt(file_path, segment_level=True, word_level=self.word_level)
        else:
            # Previously this returned a path that was never written.
            raise ValueError(f'Unsupported subtitle format: {self.subtitle_format!r}')
        return file_path