crypto-code commited on
Commit
ead7a82
1 Parent(s): 15a96ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -16
app.py CHANGED
@@ -20,6 +20,7 @@ import torchvision.transforms as transforms
20
  import av
21
  import subprocess
22
  import librosa
 
23
 
24
  args = {"model": "./ckpts/checkpoint.pth", "llama_type": "7B", "llama_dir": "./ckpts/LLaMA-2",
25
  "mert_path": "m-a-p/MERT-v1-330M", "vit_path": "google/vit-base-patch16-224", "vivit_path": "google/vivit-b-16x2-kinetics400",
@@ -33,7 +34,7 @@ class dotdict(dict):
33
 
34
  args = dotdict(args)
35
 
36
- generated_audio_files = []
37
 
38
  llama_type = args.llama_type
39
  llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
@@ -117,7 +118,7 @@ def parse_text(text, image_path, video_path, audio_path):
117
  return text, outputs
118
 
119
 
120
- def save_audio_to_local(audio, sec):
121
  global generated_audio_files
122
  if not os.path.exists('temp'):
123
  os.mkdir('temp')
@@ -126,11 +127,11 @@ def save_audio_to_local(audio, sec):
126
  scipy.io.wavfile.write(filename, rate=16000, data=audio[0])
127
  else:
128
  scipy.io.wavfile.write(filename, rate=model.generation_model.config.audio_encoder.sampling_rate, data=audio)
129
- generated_audio_files.append(filename)
130
  return filename
131
 
132
 
133
- def parse_reponse(model_outputs, audio_length_in_s):
134
  response = ''
135
  text_outputs = []
136
  for output_i, p in enumerate(model_outputs):
@@ -146,7 +147,7 @@ def parse_reponse(model_outputs, audio_length_in_s):
146
  response += '<br>'
147
  _temp_output += m.replace(' '.join([f'[AUD{i}]' for i in range(8)]), '')
148
  else:
149
- filename = save_audio_to_local(m, audio_length_in_s)
150
  print(filename)
151
  _temp_output = f'<Audio>{filename}</Audio> ' + _temp_output
152
  response += f'<audio controls playsinline><source src="./file={filename}" type="audio/wav"></audio>'
@@ -161,15 +162,15 @@ def reset_user_input():
161
  return gr.update(value='')
162
 
163
 
164
- def reset_dialog():
165
  global generated_audio_files
166
- generated_audio_files = []
167
  return [], []
168
 
169
 
170
- def reset_state():
171
  global generated_audio_files
172
- generated_audio_files = []
173
  return None, None, None, None, [], [], []
174
 
175
 
@@ -218,6 +219,7 @@ def get_audio_length(filename):
218
 
219
 
220
  def predict(
 
221
  prompt_input,
222
  image_path,
223
  audio_path,
@@ -247,28 +249,30 @@ def predict(
247
  indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
248
  video = read_video_pyav(container=container, indices=indices)
249
 
250
- if len(generated_audio_files) != 0:
251
- audio_length_in_s = get_audio_length(generated_audio_files[-1])
252
  sample_rate = 24000
253
- waveform, sr = torchaudio.load(generated_audio_files[-1])
254
  if sample_rate != sr:
255
  waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
256
  audio = torch.mean(waveform, 0)
257
  audio_length_in_s = int(len(audio)//sample_rate)
258
  print(f"Audio Length: {audio_length_in_s}")
 
 
259
  if video_path is not None:
260
  audio_length_in_s = get_video_length(video_path)
261
  print(f"Video Length: {audio_length_in_s}")
262
  if audio_path is not None:
263
  audio_length_in_s = get_audio_length(audio_path)
264
- generated_audio_files.append(audio_path)
265
  print(f"Audio Length: {audio_length_in_s}")
266
 
267
  print(image, video, audio)
268
  response = model.generate(prompts, audio, image, video, 200, temperature, top_p,
269
  audio_length_in_s=audio_length_in_s)
270
  print(response)
271
- response_chat, response_outputs = parse_reponse(response, audio_length_in_s)
272
  print('text_outputs: ', response_outputs)
273
  user_chat, user_outputs = parse_text(prompt_input, image_path, video_path, audio_path)
274
  chatbot.append((user_chat, response_chat))
@@ -319,9 +323,11 @@ with gr.Blocks() as demo:
319
 
320
  history = gr.State([])
321
  modality_cache = gr.State([])
 
322
 
323
  submitBtn.click(
324
  predict, [
 
325
  user_input,
326
  image_path,
327
  audio_path,
@@ -343,8 +349,8 @@ with gr.Blocks() as demo:
343
  show_progress=True
344
  )
345
 
346
- submitBtn.click(reset_user_input, [], [user_input])
347
- emptyBtn.click(reset_state, outputs=[
348
  image_path,
349
  audio_path,
350
  video_path,
 
20
  import av
21
  import subprocess
22
  import librosa
23
+ import uuid
24
 
25
  args = {"model": "./ckpts/checkpoint.pth", "llama_type": "7B", "llama_dir": "./ckpts/LLaMA-2",
26
  "mert_path": "m-a-p/MERT-v1-330M", "vit_path": "google/vit-base-patch16-224", "vivit_path": "google/vivit-b-16x2-kinetics400",
 
34
 
35
  args = dotdict(args)
36
 
37
+ generated_audio_files = {}
38
 
39
  llama_type = args.llama_type
40
  llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
 
118
  return text, outputs
119
 
120
 
121
+ def save_audio_to_local(uid, audio, sec):
122
  global generated_audio_files
123
  if not os.path.exists('temp'):
124
  os.mkdir('temp')
 
127
  scipy.io.wavfile.write(filename, rate=16000, data=audio[0])
128
  else:
129
  scipy.io.wavfile.write(filename, rate=model.generation_model.config.audio_encoder.sampling_rate, data=audio)
130
+ generated_audio_files[uid].append(filename)
131
  return filename
132
 
133
 
134
+ def parse_reponse(uid, model_outputs, audio_length_in_s):
135
  response = ''
136
  text_outputs = []
137
  for output_i, p in enumerate(model_outputs):
 
147
  response += '<br>'
148
  _temp_output += m.replace(' '.join([f'[AUD{i}]' for i in range(8)]), '')
149
  else:
150
+ filename = save_audio_to_local(uid, m, audio_length_in_s)
151
  print(filename)
152
  _temp_output = f'<Audio>{filename}</Audio> ' + _temp_output
153
  response += f'<audio controls playsinline><source src="./file={filename}" type="audio/wav"></audio>'
 
162
  return gr.update(value='')
163
 
164
 
165
+ def reset_dialog(uid):
166
  global generated_audio_files
167
+ generated_audio_files[uid] = []
168
  return [], []
169
 
170
 
171
+ def reset_state(uid):
172
  global generated_audio_files
173
+ generated_audio_files[uid] = []
174
  return None, None, None, None, [], [], []
175
 
176
 
 
219
 
220
 
221
  def predict(
222
+ uid,
223
  prompt_input,
224
  image_path,
225
  audio_path,
 
249
  indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
250
  video = read_video_pyav(container=container, indices=indices)
251
 
252
+ if uid in generated_audio_files and len(generated_audio_files[uid]) != 0:
253
+ audio_length_in_s = get_audio_length(generated_audio_files[uid][-1])
254
  sample_rate = 24000
255
+ waveform, sr = torchaudio.load(generated_audio_files[uid][-1])
256
  if sample_rate != sr:
257
  waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
258
  audio = torch.mean(waveform, 0)
259
  audio_length_in_s = int(len(audio)//sample_rate)
260
  print(f"Audio Length: {audio_length_in_s}")
261
+ else:
262
+ generated_audio_files[uid] = []
263
  if video_path is not None:
264
  audio_length_in_s = get_video_length(video_path)
265
  print(f"Video Length: {audio_length_in_s}")
266
  if audio_path is not None:
267
  audio_length_in_s = get_audio_length(audio_path)
268
+ generated_audio_files[uid].append(audio_path)
269
  print(f"Audio Length: {audio_length_in_s}")
270
 
271
  print(image, video, audio)
272
  response = model.generate(prompts, audio, image, video, 200, temperature, top_p,
273
  audio_length_in_s=audio_length_in_s)
274
  print(response)
275
+ response_chat, response_outputs = parse_reponse(uid, response, audio_length_in_s)
276
  print('text_outputs: ', response_outputs)
277
  user_chat, user_outputs = parse_text(prompt_input, image_path, video_path, audio_path)
278
  chatbot.append((user_chat, response_chat))
 
323
 
324
  history = gr.State([])
325
  modality_cache = gr.State([])
326
+ uid = gr.State(uuid.uuid4())
327
 
328
  submitBtn.click(
329
  predict, [
330
+ uid,
331
  user_input,
332
  image_path,
333
  audio_path,
 
349
  show_progress=True
350
  )
351
 
352
+ submitBtn.click(reset_user_input, [uid], [user_input])
353
+ emptyBtn.click(reset_state, [uid], outputs=[
354
  image_path,
355
  audio_path,
356
  video_path,