crypto-code committed
Commit 686c4ae
1 Parent(s): d9cb0bd

Update app.py

Files changed (1): app.py +11 -12
app.py CHANGED

@@ -20,6 +20,7 @@ import torchvision.transforms as transforms
 import av
 import subprocess
 import librosa
+import re
 
 args = {"model": "./ckpts/checkpoint.pth", "llama_type": "7B", "llama_dir": "./ckpts/LLaMA-2",
         "mert_path": "m-a-p/MERT-v1-330M", "vit_path": "google/vit-base-patch16-224", "vivit_path": "google/vivit-b-16x2-kinetics400",
@@ -33,8 +34,6 @@ class dotdict(dict):
 
 args = dotdict(args)
 
-generated_audio_files = []
-
 llama_type = args.llama_type
 llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
 llama_tokenzier_path = args.llama_dir
@@ -118,7 +117,6 @@ def parse_text(text, image_path, video_path, audio_path):
 
 
 def save_audio_to_local(audio, sec):
-    global generated_audio_files
     if not os.path.exists('temp'):
         os.mkdir('temp')
     filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.wav')
@@ -126,7 +124,6 @@ def save_audio_to_local(audio, sec):
         scipy.io.wavfile.write(filename, rate=16000, data=audio[0])
     else:
         scipy.io.wavfile.write(filename, rate=model.generation_model.config.audio_encoder.sampling_rate, data=audio)
-    generated_audio_files.append(filename)
     return filename
 
 
@@ -166,8 +163,6 @@ def reset_dialog():
 
 
 def reset_state():
-    global generated_audio_files
-    generated_audio_files = []
     return None, None, None, None, [], [], []
 
 
@@ -214,6 +209,12 @@ def get_video_length(filename):
 def get_audio_length(filename):
     return int(round(librosa.get_duration(path=filename)))
 
+def get_last_audio():
+    for hist in history[::-1]:
+        print(hist)
+        if "<audio controls playsinline>" in hist[1]:
+            return re.search('<audio controls playsinline><source src=\"\.\/file=(.*)\" type="audio\/wav"><\/audio>', hist[1]).group(1)
+    return None
 
 def predict(
     prompt_input,
@@ -226,7 +227,6 @@ def predict(
     history,
     modality_cache,
     audio_length_in_s):
-    global generated_audio_files
    prompts = [llama.format_prompt(prompt_input)]
    prompts = [model.tokenizer(x).input_ids for x in prompts]
    print(image_path, audio_path, video_path)
@@ -244,11 +244,11 @@
         container = av.open(video_path)
         indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
         video = read_video_pyav(container=container, indices=indices)
-
-    if len(generated_audio_files) != 0:
-        audio_length_in_s = get_audio_length(generated_audio_files[-1])
+    generated_audio_file = get_last_audio()
+    if generated_audio_file is not None:
+        audio_length_in_s = get_audio_length(generated_audio_file)
         sample_rate = 24000
-        waveform, sr = torchaudio.load(generated_audio_files[-1])
+        waveform, sr = torchaudio.load(generated_audio_file)
         if sample_rate != sr:
             waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
         audio = torch.mean(waveform, 0)
@@ -259,7 +259,6 @@
         print(f"Video Length: {audio_length_in_s}")
     if audio_path is not None:
         audio_length_in_s = get_audio_length(audio_path)
-        generated_audio_files.append(audio_path)
         print(f"Audio Length: {audio_length_in_s}")
 
     print(image, video, audio)
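The commit's core change replaces the global generated_audio_files list with get_last_audio(), which regex-matches the <audio> tag embedded in earlier bot replies to recover the most recently generated file. A minimal, self-contained sketch of that lookup follows; the history entry and file path are hypothetical, and history is passed in explicitly here, whereas the committed helper reads it from the caller's scope:

    import re

    # Tag format assumed to match what the app embeds in bot replies.
    AUDIO_TAG = re.compile(r'<audio controls playsinline><source src="\./file=(.*)" type="audio/wav"></audio>')

    def last_generated_audio(history):
        # history: list of (user_message, bot_message) pairs, newest last.
        for _, bot_message in reversed(history):
            match = AUDIO_TAG.search(bot_message)
            if match:
                return match.group(1)  # path captured from src="./file=<path>"
        return None

    # Hypothetical history entry for illustration:
    history = [("make a happy tune",
                '<audio controls playsinline><source src="./file=temp/abc123.wav" type="audio/wav"></audio>')]
    print(last_generated_audio(history))  # -> temp/abc123.wav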
 
20
  import av
21
  import subprocess
22
  import librosa
23
+ import re
24
 
25
  args = {"model": "./ckpts/checkpoint.pth", "llama_type": "7B", "llama_dir": "./ckpts/LLaMA-2",
26
  "mert_path": "m-a-p/MERT-v1-330M", "vit_path": "google/vit-base-patch16-224", "vivit_path": "google/vivit-b-16x2-kinetics400",
 
34
 
35
  args = dotdict(args)
36
 
 
 
37
  llama_type = args.llama_type
38
  llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
39
  llama_tokenzier_path = args.llama_dir
 
117
 
118
 
119
  def save_audio_to_local(audio, sec):
 
120
  if not os.path.exists('temp'):
121
  os.mkdir('temp')
122
  filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.wav')
 
124
  scipy.io.wavfile.write(filename, rate=16000, data=audio[0])
125
  else:
126
  scipy.io.wavfile.write(filename, rate=model.generation_model.config.audio_encoder.sampling_rate, data=audio)
 
127
  return filename
128
 
129
 
 
163
 
164
 
165
  def reset_state():
 
 
166
  return None, None, None, None, [], [], []
167
 
168
 
 
209
  def get_audio_length(filename):
210
  return int(round(librosa.get_duration(path=filename)))
211
 
212
+ def get_last_audio():
213
+ for hist in history[::-1]:
214
+ print(hist)
215
+ if "<audio controls playsinline>" in hist[1]:
216
+ return re.search('<audio controls playsinline><source src=\"\.\/file=(.*)\" type="audio\/wav"><\/audio>', hist[1]).group(1)
217
+ return None
218
 
219
  def predict(
220
  prompt_input,
 
227
  history,
228
  modality_cache,
229
  audio_length_in_s):
 
230
  prompts = [llama.format_prompt(prompt_input)]
231
  prompts = [model.tokenizer(x).input_ids for x in prompts]
232
  print(image_path, audio_path, video_path)
 
244
  container = av.open(video_path)
245
  indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
246
  video = read_video_pyav(container=container, indices=indices)
247
+ generated_audio_file = get_last_audio()
248
+ if generated_audio_file is not None:
249
+ audio_length_in_s = get_audio_length(generated_audio_file)
250
  sample_rate = 24000
251
+ waveform, sr = torchaudio.load(generated_audio_file)
252
  if sample_rate != sr:
253
  waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
254
  audio = torch.mean(waveform, 0)
 
259
  print(f"Video Length: {audio_length_in_s}")
260
  if audio_path is not None:
261
  audio_length_in_s = get_audio_length(audio_path)
 
262
  print(f"Audio Length: {audio_length_in_s}")
263
 
264
  print(image, video, audio)
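For reference, the audio preprocessing that the hunk at -244,11 leaves in place: load the WAV with torchaudio, resample to the model's 24 kHz rate when the file's rate differs, and average channels down to mono. A sketch under those assumptions (the function name and target_sr default are illustrative, not part of the commit):

    import torch
    import torchaudio

    def load_audio_for_model(path, target_sr=24000):
        # Load the WAV; waveform has shape (channels, samples).
        waveform, sr = torchaudio.load(path)
        if sr != target_sr:
            waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=target_sr)
        # Downmix to mono by averaging channels, as predict() does with torch.mean(waveform, 0).
        return torch.mean(waveform, 0)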