nelikCode committed on
Commit
1044a67
·
verified ·
1 Parent(s): 83a90e2

chore: code refactor

Browse files
Files changed (1) hide show
  1. app.py +57 -15
app.py CHANGED
@@ -6,14 +6,19 @@ import whisper
6
  from moviepy.editor import (
7
  AudioFileClip,
8
  ColorClip,
9
- CompositeVideoClip,
10
  VideoFileClip,
11
  concatenate_videoclips,
12
  )
13
- from moviepy.video.VideoClip import TextClip
14
 
15
 
16
- def generate_srt_file(transcription_result, srt_file_path, lag=0):
 
 
 
 
 
 
 
17
  with open(srt_file_path, "w") as file:
18
  for i, segment in enumerate(transcription_result["segments"], start=1):
19
  # Adjusting times for lag
@@ -28,7 +33,17 @@ def generate_srt_file(transcription_result, srt_file_path, lag=0):
28
  file.write(f"{i}\n{start_srt} --> {end_srt}\n{text}\n\n")
29
 
30
 
31
- def get_srt_filename(video_path, audio_path):
 
 
 
 
 
 
 
 
 
 
32
  if video_path is not None:
33
  return os.path.splitext(os.path.basename(video_path))[0] + ".srt"
34
  else:
@@ -36,14 +51,33 @@ def get_srt_filename(video_path, audio_path):
36
 
37
 
38
  def generate_video(
39
- audio_path, video_path, input, language, lag, progress=gr.Progress(track_tqdm=True)
40
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  if audio_path is None and video_path is None:
42
- raise ValueError("Please upload an audio or video file.")
43
  if input == "Video" and video_path is None:
44
- raise ValueError("Please upload a video file.")
45
  if input == "Audio" and audio_path is None:
46
- raise ValueError("Please upload an audio file.")
47
  progress(0.0, "Checking input...")
48
  if input == "Video":
49
  progress(0.0, "Extracting audio from video...")
@@ -55,7 +89,7 @@ def generate_video(
55
 
56
  # Transcribe audio
57
  progress(0.1, "Transcribing audio...")
58
- result = model.transcribe(audio_path, language=language)
59
  progress(0.30, "Audio transcribed!")
60
 
61
  # Generate SRT file
@@ -72,7 +106,6 @@ def generate_video(
72
  else:
73
  # we simply extend the original video with a black screen at the end of duration lag
74
  video = VideoFileClip(video_path)
75
- fps = video.fps
76
  black_screen = ColorClip(
77
  size=video.size, color=(0, 0, 0), duration=lag
78
  ).set_fps(1)
@@ -96,7 +129,17 @@ def generate_video(
96
  return output_video_path, srt_file_path
97
 
98
 
99
- def download_srt(audio_input, video_input):
 
 
 
 
 
 
 
 
 
 
100
  srt_file_path = get_srt_filename(video_input, audio_input)
101
  if os.path.exists(srt_file_path):
102
  return srt_file_path
@@ -106,9 +149,8 @@ def download_srt(audio_input, video_input):
106
 
107
  if __name__ == "__main__":
108
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
109
- model = whisper.load_model("base", device=DEVICE)
110
 
111
- # Gradio Blocks implementation
112
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
113
  gr.Markdown(
114
  """
@@ -176,4 +218,4 @@ if __name__ == "__main__":
176
  outputs=srt_file_output,
177
  )
178
 
179
- demo.launch()
 
6
  from moviepy.editor import (
7
  AudioFileClip,
8
  ColorClip,
 
9
  VideoFileClip,
10
  concatenate_videoclips,
11
  )
 
12
 
13
 
14
+ def generate_srt_file(transcription_result: dict, srt_file_path: str, lag=0) -> None:
15
+ """
16
+ Write and save an SRT file from the transcription result.
17
+
18
+ Args:
19
+ transcription_result: The transcription result from Whisper model.
20
+ srt_file_path: The path to save the SRT file.
21
+ """
22
  with open(srt_file_path, "w") as file:
23
  for i, segment in enumerate(transcription_result["segments"], start=1):
24
  # Adjusting times for lag
 
33
  file.write(f"{i}\n{start_srt} --> {end_srt}\n{text}\n\n")
34
 
35
 
36
+ def get_srt_filename(video_path: str, audio_path: str = None) -> str:
37
+ """
38
+ Get the SRT filename based on the input video or audio file.
39
+
40
+ Args:
41
+ video_path: The path to the video file.
42
+ audio_path: The path to the audio file.
43
+
44
+ Returns:
45
+ The SRT filename.
46
+ """
47
  if video_path is not None:
48
  return os.path.splitext(os.path.basename(video_path))[0] + ".srt"
49
  else:
 
51
 
52
 
53
  def generate_video(
54
+ audio_path: str,
55
+ video_path: str,
56
+ input: str,
57
+ language: str,
58
+ lag: int,
59
+ progress: gr.Progress = gr.Progress(track_tqdm=True),
60
+ ) -> tuple[str, str]:
61
+ """
62
+ Generate a subtitled video from the input audio or video file.
63
+
64
+ Args:
65
+ audio_path: The path to the audio file.
66
+ video_path: The path to the video file.
67
+ input: The type of input file (audio or video).
68
+ language: The language code for transcription.
69
+ lag: The lag time in seconds to delay the transcription.
70
+ progress: The progress bar to show the progress of the task.
71
+
72
+ Returns:
73
+ The path to the generated video file and the SRT file.
74
+ """
75
  if audio_path is None and video_path is None:
76
+ raise gr.Error("Please upload an audio or video file.")
77
  if input == "Video" and video_path is None:
78
+ raise gr.Error("Please upload a video file.")
79
  if input == "Audio" and audio_path is None:
80
+ raise gr.Error("Please upload an audio file.")
81
  progress(0.0, "Checking input...")
82
  if input == "Video":
83
  progress(0.0, "Extracting audio from video...")
 
89
 
90
  # Transcribe audio
91
  progress(0.1, "Transcribing audio...")
92
+ result = MODEL.transcribe(audio_path, language=language)
93
  progress(0.30, "Audio transcribed!")
94
 
95
  # Generate SRT file
 
106
  else:
107
  # we simply extend the original video with a black screen at the end of duration lag
108
  video = VideoFileClip(video_path)
 
109
  black_screen = ColorClip(
110
  size=video.size, color=(0, 0, 0), duration=lag
111
  ).set_fps(1)
 
129
  return output_video_path, srt_file_path
130
 
131
 
132
+ def download_srt(audio_input: str, video_input: str) -> str:
133
+ """
134
+ Download the SRT file based on the input audio or video file.
135
+
136
+ Args:
137
+ audio_input: The path to the audio file.
138
+ video_input: The path to the video file.
139
+
140
+ Returns:
141
+ The path to the downloaded SRT file.
142
+ """
143
  srt_file_path = get_srt_filename(video_input, audio_input)
144
  if os.path.exists(srt_file_path):
145
  return srt_file_path
 
149
 
150
  if __name__ == "__main__":
151
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
152
+ MODEL = whisper.load_model("base", device=DEVICE)
153
 
 
154
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
155
  gr.Markdown(
156
  """
 
218
  outputs=srt_file_output,
219
  )
220
 
221
+ demo.launch()