crypto-code commited on
Commit
e77fc2d
1 Parent(s): 686c4ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -20,7 +20,6 @@ import torchvision.transforms as transforms
20
  import av
21
  import subprocess
22
  import librosa
23
- import re
24
 
25
  args = {"model": "./ckpts/checkpoint.pth", "llama_type": "7B", "llama_dir": "./ckpts/LLaMA-2",
26
  "mert_path": "m-a-p/MERT-v1-330M", "vit_path": "google/vit-base-patch16-224", "vivit_path": "google/vivit-b-16x2-kinetics400",
@@ -34,6 +33,8 @@ class dotdict(dict):
34
 
35
  args = dotdict(args)
36
 
 
 
37
  llama_type = args.llama_type
38
  llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
39
  llama_tokenzier_path = args.llama_dir
@@ -117,6 +118,7 @@ def parse_text(text, image_path, video_path, audio_path):
117
 
118
 
119
  def save_audio_to_local(audio, sec):
 
120
  if not os.path.exists('temp'):
121
  os.mkdir('temp')
122
  filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.wav')
@@ -124,6 +126,7 @@ def save_audio_to_local(audio, sec):
124
  scipy.io.wavfile.write(filename, rate=16000, data=audio[0])
125
  else:
126
  scipy.io.wavfile.write(filename, rate=model.generation_model.config.audio_encoder.sampling_rate, data=audio)
 
127
  return filename
128
 
129
 
@@ -159,10 +162,14 @@ def reset_user_input():
159
 
160
 
161
  def reset_dialog():
 
 
162
  return [], []
163
 
164
 
165
  def reset_state():
 
 
166
  return None, None, None, None, [], [], []
167
 
168
 
@@ -209,12 +216,6 @@ def get_video_length(filename):
209
  def get_audio_length(filename):
210
  return int(round(librosa.get_duration(path=filename)))
211
 
212
- def get_last_audio():
213
- for hist in history[::-1]:
214
- print(hist)
215
- if "<audio controls playsinline>" in hist[1]:
216
- return re.search('<audio controls playsinline><source src=\"\.\/file=(.*)\" type="audio\/wav"><\/audio>', hist[1]).group(1)
217
- return None
218
 
219
  def predict(
220
  prompt_input,
@@ -227,6 +228,7 @@ def predict(
227
  history,
228
  modality_cache,
229
  audio_length_in_s):
 
230
  prompts = [llama.format_prompt(prompt_input)]
231
  prompts = [model.tokenizer(x).input_ids for x in prompts]
232
  print(image_path, audio_path, video_path)
@@ -244,11 +246,11 @@ def predict(
244
  container = av.open(video_path)
245
  indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
246
  video = read_video_pyav(container=container, indices=indices)
247
- generated_audio_file = get_last_audio()
248
- if generated_audio_file is not None:
249
- audio_length_in_s = get_audio_length(generated_audio_file)
250
  sample_rate = 24000
251
- waveform, sr = torchaudio.load(generated_audio_file)
252
  if sample_rate != sr:
253
  waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
254
  audio = torch.mean(waveform, 0)
@@ -259,6 +261,7 @@ def predict(
259
  print(f"Video Length: {audio_length_in_s}")
260
  if audio_path is not None:
261
  audio_length_in_s = get_audio_length(audio_path)
 
262
  print(f"Audio Length: {audio_length_in_s}")
263
 
264
  print(image, video, audio)
@@ -350,4 +353,4 @@ with gr.Blocks() as demo:
350
  ], show_progress=True)
351
 
352
  if __name__ == "__main__":
353
- demo.launch()
 
20
  import av
21
  import subprocess
22
  import librosa
 
23
 
24
  args = {"model": "./ckpts/checkpoint.pth", "llama_type": "7B", "llama_dir": "./ckpts/LLaMA-2",
25
  "mert_path": "m-a-p/MERT-v1-330M", "vit_path": "google/vit-base-patch16-224", "vivit_path": "google/vivit-b-16x2-kinetics400",
 
33
 
34
  args = dotdict(args)
35
 
36
+ generated_audio_files = []
37
+
38
  llama_type = args.llama_type
39
  llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
40
  llama_tokenzier_path = args.llama_dir
 
118
 
119
 
120
  def save_audio_to_local(audio, sec):
121
+ global generated_audio_files
122
  if not os.path.exists('temp'):
123
  os.mkdir('temp')
124
  filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.wav')
 
126
  scipy.io.wavfile.write(filename, rate=16000, data=audio[0])
127
  else:
128
  scipy.io.wavfile.write(filename, rate=model.generation_model.config.audio_encoder.sampling_rate, data=audio)
129
+ generated_audio_files.append(filename)
130
  return filename
131
 
132
 
 
162
 
163
 
164
  def reset_dialog():
165
+ global generated_audio_files
166
+ generated_audio_files = []
167
  return [], []
168
 
169
 
170
  def reset_state():
171
+ global generated_audio_files
172
+ generated_audio_files = []
173
  return None, None, None, None, [], [], []
174
 
175
 
 
216
  def get_audio_length(filename):
217
  return int(round(librosa.get_duration(path=filename)))
218
 
 
 
 
 
 
 
219
 
220
  def predict(
221
  prompt_input,
 
228
  history,
229
  modality_cache,
230
  audio_length_in_s):
231
+ global generated_audio_files
232
  prompts = [llama.format_prompt(prompt_input)]
233
  prompts = [model.tokenizer(x).input_ids for x in prompts]
234
  print(image_path, audio_path, video_path)
 
246
  container = av.open(video_path)
247
  indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
248
  video = read_video_pyav(container=container, indices=indices)
249
+
250
+ if len(generated_audio_files) != 0:
251
+ audio_length_in_s = get_audio_length(generated_audio_files[-1])
252
  sample_rate = 24000
253
+ waveform, sr = torchaudio.load(generated_audio_files[-1])
254
  if sample_rate != sr:
255
  waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
256
  audio = torch.mean(waveform, 0)
 
261
  print(f"Video Length: {audio_length_in_s}")
262
  if audio_path is not None:
263
  audio_length_in_s = get_audio_length(audio_path)
264
+ generated_audio_files.append(audio_path)
265
  print(f"Audio Length: {audio_length_in_s}")
266
 
267
  print(image, video, audio)
 
353
  ], show_progress=True)
354
 
355
  if __name__ == "__main__":
356
+ demo.launch()