Raphael committed on
Commit
7d79fa6
1 Parent(s): 4581d71

Improve translation and subtitles sync

Browse files

Signed-off-by: Raphael <oOraph@users.noreply.github.com>

Files changed (1) hide show
  1. app.py +101 -41
app.py CHANGED
@@ -10,6 +10,7 @@ import gradio as gr
10
  import moviepy.editor as mp
11
  import numpy as np
12
  import pysrt
 
13
  import torch
14
  from transformers import pipeline
15
  import yt_dlp
@@ -22,9 +23,10 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(level
22
  LOG = logging.getLogger(__name__)
23
  CLIP_SECONDS = 20
24
  SLICES = 4
25
- SLICE_DURATION = CLIP_SECONDS / SLICES
26
  # At most 6 mins
27
  MAX_CHUNKS = 45
 
28
 
29
  asr_kwargs = {
30
  "task": "automatic-speech-recognition",
@@ -118,7 +120,7 @@ def process_video(basedir: str, duration, translate: bool):
118
  subs = translation(transcriptions, translate)
119
  srt_file = build_srt_clips(subs, basedir)
120
  summary = summarize(transcriptions, translate)
121
- return srt_file, ' '.join(subs).strip(), summary
122
 
123
 
124
  def transcription(audio_dir: str, duration):
@@ -141,74 +143,131 @@ def transcription(audio_dir: str, duration):
141
  t = asr(d, max_new_tokens=10000)
142
  transcriptions.extend(t)
143
 
144
- transcriptions = [t['text'] for t in transcriptions]
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  elapsed = time.time() - start
146
  LOG.info("Transcription done, elapsed %.2f seconds", elapsed)
147
- return transcriptions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
 
150
  def translation(transcriptions, translate):
 
151
  if translate:
152
  LOG.info("Performing translation")
153
  start = time.time()
154
- translations = translator(transcriptions)
155
- translations = [t['translation_text'] for t in translations]
 
 
 
156
  elapsed = time.time() - start
157
  LOG.info("Translation done, elapsed %.2f seconds", elapsed)
 
158
  else:
159
- translations = transcriptions
160
- return translations
161
 
162
 
163
  def summarize(transcriptions, translate):
164
  LOG.info("Generating video summary")
165
- whole_text = ' '.join(transcriptions).strip()
166
- word_count = len(whole_text.split())
167
  summary = summarizer(whole_text)
168
  # min_length=word_count // 4 + 1,
169
  # max_length=word_count // 2 + 1)
170
- summary = translation([summary[0]['summary_text']], translate)[0]
171
- return summary
172
 
173
 
174
- def subs_to_timed_segments(subtitles: list[str]):
175
- LOG.info("Building srt segments")
176
- all_chunks = []
177
  for sub in subtitles:
178
- chunks = np.array_split(sub.split(' '), SLICES)
179
- all_chunks.extend(chunks)
180
-
181
- subs = []
182
- for c in all_chunks:
183
- c = ' '.join(c)
184
- subs.append(c)
185
-
186
- segments = []
187
- for i, c in enumerate(subs):
188
- segments.append({
189
- 'text': c.strip(),
190
- 'start': i * SLICE_DURATION,
191
- 'end': (i + 1) * SLICE_DURATION
192
- })
193
-
194
- return segments
195
 
196
 
197
- def build_srt_clips(subs, basedir):
198
 
199
  LOG.info("Generating subtitles")
200
- segments = subs_to_timed_segments(subs)
201
 
202
  LOG.info("Building srt clips")
203
- max_text_len = 30
204
  subtitles = pysrt.SubRipFile()
205
- first = True
206
  for segment in segments:
207
- start = segment['start'] * 1000
208
- if first:
209
- start += 3000
210
- first = False
211
- end = segment['end'] * 1000
212
  text = segment['text']
213
  text = text.strip()
214
  if len(text) < max_text_len:
@@ -250,4 +309,5 @@ iface = gr.Interface(
250
  gr.Text(label="Full transcription")
251
  ])
252
 
 
253
  iface.launch()
 
10
  import moviepy.editor as mp
11
  import numpy as np
12
  import pysrt
13
+ import re
14
  import torch
15
  from transformers import pipeline
16
  import yt_dlp
 
23
  LOG = logging.getLogger(__name__)
24
  CLIP_SECONDS = 20
25
  SLICES = 4
26
+ # SLICE_DURATION = CLIP_SECONDS / SLICES
27
  # At most 6 mins
28
  MAX_CHUNKS = 45
29
+ SENTENCE_SPLIT = re.compile(r'([^.?!]*[.?!]+)([^.?!].*|$)')
30
 
31
  asr_kwargs = {
32
  "task": "automatic-speech-recognition",
 
120
  subs = translation(transcriptions, translate)
121
  srt_file = build_srt_clips(subs, basedir)
122
  summary = summarize(transcriptions, translate)
123
+ return srt_file, ' '.join([s['text'].strip() for s in subs]).strip(), summary
124
 
125
 
126
  def transcription(audio_dir: str, duration):
 
143
  t = asr(d, max_new_tokens=10000)
144
  transcriptions.extend(t)
145
 
146
+ transcriptions = [
147
+ {
148
+ 'text': t['text'].strip(),
149
+ 'start': i * CLIP_SECONDS * 1000,
150
+ 'end': (i + 1) * CLIP_SECONDS * 1000
151
+ } for i, t in enumerate(transcriptions)
152
+ ]
153
+
154
+ if transcriptions:
155
+ transcriptions[0]['start'] += 2500
156
+
157
+ # Will improve the translation
158
+ segments = segments_on_sentence_boundaries(transcriptions)
159
+
160
  elapsed = time.time() - start
161
  LOG.info("Transcription done, elapsed %.2f seconds", elapsed)
162
+ return segments
163
+
164
+
165
def segments_on_sentence_boundaries(segments):
    """Re-cut timed segments so that each one ends on a sentence boundary.

    Walks the list pairwise: when a segment's text does not end in '.',
    '?' or '!', the leading sentence fragment of the *next* segment is
    moved onto it, and the boundary between the two segments' time spans
    is shifted forward proportionally to the amount of text moved.

    NOTE(review): mutates the input segment dicts in place ('text',
    'start', 'end' are reassigned) in addition to returning a new list —
    callers must not rely on the input being unchanged.

    :param segments: list of dicts with 'text', 'start', 'end' keys
        (start/end presumably milliseconds, per transcription() — confirm).
    :return: new list of segment dicts aligned on sentence boundaries;
        empty-text segments are dropped.
    """

    LOG.info("Segmenting along sentence boundaries for better translations")

    new_segments = []
    i = 0
    while i < len(segments):
        s = segments[i]
        text = s['text'].strip()
        # Drop segments whose text is empty.
        if not text:
            i += 1
            continue

        # Last segment: nothing after it to borrow a fragment from.
        if i == len(segments)-1:
            new_segments.append(s)
            break

        next_s = segments[i+1]

        next_text = next_s['text'].strip()
        # Next segment empty, or current one already ends a sentence:
        # keep the current segment as-is.
        if not next_text or (text[-1] in ['.', '?', '!']):
            new_segments.append(s)
            i += 1
            continue

        # SENTENCE_SPLIT group(1) = text up to and including the first
        # run of sentence terminators, group(2) = the remainder (or '').
        m = SENTENCE_SPLIT.match(next_s['text'].strip())
        if not m:
            # Fallback (e.g. next segment has no terminator at all):
            # merge the two segments wholesale and consume both.
            LOG.warning("Bad pattern matching on segment [%s], "
                        "this should not be possible", next_s['text'])
            s['end'] = next_s['end']
            s['text'] = '{} {}'.format(s['text'].strip(), next_s['text'].strip())
            new_segments.append(s)
            i += 2
        else:
            before = m.group(1)
            after = m.group(2)
            next_segment_duration = next_s['end'] - next_s['start']
            # Shift the time boundary proportionally to the share of
            # characters moved out of the next segment.
            ratio = len(before) / len(next_text)
            add_time = int(next_segment_duration * ratio)
            s['end'] = s['end'] + add_time
            s['text'] = '{} {}'.format(text, before)
            next_s['start'] = next_s['start'] + add_time
            next_s['text'] = after.strip()
            new_segments.append(s)
            i += 1

    return new_segments
212
 
213
 
214
def translation(transcriptions, translate):
    """Translate each segment's text, keeping its timing metadata.

    When *translate* is false the input list is returned unchanged;
    otherwise every segment is shallow-copied and its 'text' replaced by
    the translator pipeline's output.

    :param transcriptions: list of dicts with at least a 'text' key.
    :param translate: whether to run the translation pipeline.
    :return: list of segment dicts (input list itself when not translating).
    """
    if not translate:
        return transcriptions

    LOG.info("Performing translation")
    start = time.time()
    results = translator([seg['text'] for seg in transcriptions])
    translated = []
    for seg, res in zip(transcriptions, results):
        entry = seg.copy()
        entry['text'] = res['translation_text'].strip()
        translated.append(entry)
    elapsed = time.time() - start
    LOG.info("Translation done, elapsed %.2f seconds", elapsed)
    LOG.info('Translations %s', translated)
    return translated
230
 
231
 
232
def summarize(transcriptions, translate):
    """Produce a (possibly translated) summary of the whole transcript.

    Joins every segment's text into one string, feeds it to the
    summarizer pipeline, then routes the summary through translation()
    so it matches the subtitle language.

    :param transcriptions: list of dicts with a 'text' key.
    :param translate: forwarded to translation().
    :return: the summary text (str).
    """
    LOG.info("Generating video summary")
    whole_text = ' '.join(t['text'].strip() for t in transcriptions)
    # word_count = len(whole_text.split())
    summary = summarizer(whole_text)
    # min_length=word_count // 4 + 1,
    # max_length=word_count // 2 + 1)
    translated = translation([{'text': summary[0]['summary_text']}], translate)
    return translated[0]['text']
241
 
242
 
243
def segment_slices(subtitles: list[dict]):
    """Split each timed subtitle segment into SLICES equal sub-segments.

    Each segment's text is divided into SLICES roughly equal word chunks
    and its time span is divided evenly among them, so subtitles advance
    within a clip instead of staying on screen for its whole duration.

    Fixed: the previous annotation said ``list[str]`` but the elements
    are dicts — the body indexes ``sub['text']``/``sub['start']``/``sub['end']``.

    :param subtitles: list of dicts with 'text', 'start', 'end' keys
        (start/end in milliseconds).
    :return: list of slice dicts with the same keys; start/end may be
        floats after the division.
    """
    LOG.info("Building srt segments slices")
    slices = []
    for sub in subtitles:
        # np.array_split balances the words across slices; some chunks
        # may be empty when a segment has fewer words than SLICES.
        chunks = np.array_split(sub['text'].split(' '), SLICES)
        start = sub['start']
        duration = sub['end'] - start
        for i in range(SLICES):
            slices.append({
                'text': ' '.join(chunks[i]),
                'start': start + i * duration / SLICES,
                'end': start + (i + 1) * duration / SLICES
            })
    return slices
 
 
 
 
 
 
258
 
259
 
260
+ def build_srt_clips(segments, basedir):
261
 
262
  LOG.info("Generating subtitles")
263
+ segments = segment_slices(segments)
264
 
265
  LOG.info("Building srt clips")
266
+ max_text_len = 45
267
  subtitles = pysrt.SubRipFile()
 
268
  for segment in segments:
269
+ start = segment['start']
270
+ end = segment['end']
 
 
 
271
  text = segment['text']
272
  text = text.strip()
273
  if len(text) < max_text_len:
 
309
  gr.Text(label="Full transcription")
310
  ])
311
 
312
+ # iface.launch(server_name="0.0.0.0", server_port=6443)
313
  iface.launch()