# NOTE: Hugging Face Space page chrome (runtime status, file size, commit hashes,
# and a line-number gutter) was captured with this source; it is not Python and
# has been reduced to this comment. The application code begins below.
import gradio as gr
import huggingface_hub
import os
import subprocess
import threading
import shutil
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile
from moviepy.editor import VideoFileClip, AudioFileClip
# download model
# Fetch the VTA-LDM checkpoint repo from the Hugging Face Hub into ./ckpt at
# import time, so the weights are present before any inference request arrives.
huggingface_hub.snapshot_download(
    repo_id='ariesssxu/vta-ldm-clip4clip-v-large',
    local_dir='./ckpt/vta-ldm-clip4clip-v-large'
)
def stream_output(pipe):
    """Relay lines from a text *pipe* to stdout as they arrive.

    Used to forward a subprocess's stdout/stderr into the Space logs;
    stops when the pipe yields an empty string (EOF).
    """
    while True:
        chunk = pipe.readline()
        if chunk == '':
            break
        print(chunk, end='')
def print_directory_contents(path):
    """Print the directory tree rooted at *path*, indenting 4 spaces per level."""
    for root, _dirs, files in os.walk(path):
        depth = root.replace(path, '').count(os.sep)
        pad = ' ' * (4 * depth)
        print(f"{pad}{os.path.basename(root)}/")
        child_pad = ' ' * (4 * (depth + 1))
        for name in files:
            print(f"{child_pad}{name}")
# Print the ckpt directory contents
# Logged at startup so the Space logs show exactly what snapshot_download fetched.
print_directory_contents('./ckpt')
def get_wav_files(path):
    """Walk *path*, print an indented tree listing, and return all .wav paths.

    .wav entries are printed with their full path; everything else is printed
    by bare filename. Returns the list of .wav file paths in walk order.
    """
    collected = []
    for root, _dirs, files in os.walk(path):
        depth = root.replace(path, '').count(os.sep)
        print(f"{' ' * 4 * depth}{os.path.basename(root)}/")
        child_pad = ' ' * 4 * (depth + 1)
        for name in files:
            full = os.path.join(root, name)
            if name.lower().endswith('.wav'):
                collected.append(full)
                print(f"{child_pad}{full}")
            else:
                print(f"{child_pad}{name}")
    return collected
def check_outputs_folder(folder_path):
    """Empty *folder_path* in place if it is an existing directory.

    Files and symlinks are unlinked; subdirectories are removed recursively.
    Failures are reported per-entry but never raised. If the folder does not
    exist, a message is printed and nothing else happens.
    """
    if not (os.path.exists(folder_path) and os.path.isdir(folder_path)):
        print(f'The folder {folder_path} does not exist.')
        return
    for entry in os.listdir(folder_path):
        file_path = os.path.join(folder_path, entry)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)  # plain file or symlink
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)  # whole subtree
        except Exception as e:
            print(f'Failed to delete {file_path}. Reason: {e}')
def plot_spectrogram(wav_file, output_image):
    """Render a grayscale spectrogram of *wav_file* and save it to *output_image*.

    Parameters
    ----------
    wav_file : str
        Path to a PCM WAV file; stereo input is averaged down to mono.
    output_image : str
        Path the spectrogram image is written to (300 dpi, no axes/ticks).
    """
    # Read the WAV file
    sample_rate, audio_data = wavfile.read(wav_file)
    # Stereo data arrives as (n_samples, 2); average channels to get mono.
    if len(audio_data.shape) == 2:
        audio_data = audio_data.mean(axis=1)
    # Create a plot for the spectrogram
    plt.figure(figsize=(10, 2))
    plt.specgram(audio_data, Fs=sample_rate, NFFT=1024, noverlap=512, cmap='gray', aspect='auto')
    # Remove gridlines and ticks for a cleaner look
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    # Save the plot as an image file
    plt.savefig(output_image, bbox_inches='tight', pad_inches=0, dpi=300)
    # BUG FIX: was `plt.close` (attribute access, never invoked), which leaked
    # one figure per call and triggers matplotlib's open-figure warnings.
    plt.close()
def merge_audio_to_video(input_vid, input_aud):
    """Mux the audio file *input_aud* onto the video *input_vid*.

    Writes the result to "output_video.mp4" (H.264 video, AAC audio) in the
    working directory and returns that path.
    """
    # Load the video and the replacement audio track
    video = VideoFileClip(input_vid)
    new_audio = AudioFileClip(input_aud)
    try:
        # Attach the new audio and render the combined file
        video_with_new_audio = video.set_audio(new_audio)
        video_with_new_audio.write_videofile("output_video.mp4", codec='libx264', audio_codec='aac')
    finally:
        # FIX: close both clips so their ffmpeg readers / file handles are
        # released even if rendering fails (previously they were leaked).
        new_audio.close()
        video.close()
    return "output_video.mp4"
def infer(video_in):
    """Generate audio for *video_in*; return (wav_path, spectrogram_path, merged_mp4_path).

    Pipeline: empty ./outputs/tmp, cap the input clip at 10 seconds, run
    inference_from_video.py as a subprocess over the clip's folder, then plot a
    spectrogram of the first generated .wav and mux it back onto the video.
    """
    # check if 'outputs' dir exists and empty it if necessary
    check_outputs_folder('./outputs/tmp')
    # Need to find path to gradio temp vid from video input
    print(f"VIDEO IN PATH: {video_in}")
    # Get the directory name — passed to the inference script as --data_path,
    # so the script processes the containing Gradio temp folder, not one file.
    folder_path = os.path.dirname(video_in)
    # Path to the input video file
    input_video_path = video_in
    # Load the video file
    video = VideoFileClip(input_video_path)
    # Get the length of the video in seconds (truncated to int for --max_duration)
    video_duration = int(video.duration)
    print(f"Video duration: {video_duration} seconds")
    # Check if the video duration is more than 10 seconds
    if video_duration > 10:
        # Cut the video to the first 10 seconds
        cut_video = video.subclip(0, 10)
        video_duration = 10
        # Extract the directory and filename
        dir_name = os.path.dirname(input_video_path)
        base_name = os.path.basename(input_video_path)
        # Generate the new filename
        new_base_name = base_name.replace(".mp4", "_10sec_cut.mp4")
        output_video_path = os.path.join(dir_name, new_base_name)
        # Save the cut video
        cut_video.write_videofile(output_video_path, codec='libx264', audio_codec='aac')
        print(f"Cut video saved as: {output_video_path}")
        video_in = output_video_path
        # Delete the original video file so --data_path only sees the trimmed clip
        os.remove(input_video_path)
        print(f"Original video file {input_video_path} deleted.")
        # NOTE(review): `video`/`cut_video` are never closed; moviepy may hold an
        # ffmpeg reader open on the just-deleted file — consider video.close().
    else:
        print("Video is 10 seconds or shorter; no cutting needed.")
    # Execute the inference command
    command = ['python', 'inference_from_video.py',
               '--original_args', 'ckpt/vta-ldm-clip4clip-v-large/summary.jsonl',
               '--model', 'ckpt/vta-ldm-clip4clip-v-large/pytorch_model_2.bin',
               '--data_path', folder_path,
               '--max_duration', f"{video_duration}"
               ]
    # Line-buffered text pipes so output can be streamed into the Space logs live
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1)
    # Create threads to handle stdout and stderr (draining both avoids pipe deadlock)
    stdout_thread = threading.Thread(target=stream_output, args=(process.stdout,))
    stderr_thread = threading.Thread(target=stream_output, args=(process.stderr,))
    # Start the threads
    stdout_thread.start()
    stderr_thread.start()
    # Wait for the process to complete and the threads to finish
    process.wait()
    stdout_thread.join()
    stderr_thread.join()
    print("Inference script finished with return code:", process.returncode)
    # Need to find where are the results stored, default should be "./outputs/tmp"
    # Print the outputs directory contents
    print_directory_contents('./outputs/tmp')
    wave_files = get_wav_files('./outputs/tmp')
    print(wave_files)
    # NOTE(review): wave_files[0] raises IndexError when the script produced no
    # .wav (e.g. nonzero return code) — consider an explicit check + gr.Error.
    plot_spectrogram(wave_files[0], 'spectrogram.png')
    final_merged_out = merge_audio_to_video(video_in, wave_files[0])
    return wave_files[0], 'spectrogram.png', final_merged_out
# --- Gradio UI: one video input, three outputs (audio, spectrogram, merged video)
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Video-to-Audio Generation with Hidden Alignment")
        # Badge links to the project page and the paper
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href='https://sites.google.com/view/vta-ldm'>
                <img src='https://img.shields.io/badge/Project-Page-Green'>
            </a>
            <a href='https://huggingface.co/papers/2407.07464'>
                <img src='https://img.shields.io/badge/HF-Paper-red'>
            </a>
        </div>
        """)
        with gr.Row():
            with gr.Column():
                video_in = gr.Video(label='Video IN')
                submit_btn = gr.Button("Submit")
                # Bundled sample clips for one-click demos
                gr.Examples(
                    examples = [
                        ["./examples/lion_gt.mp4"],
                        ["./examples/ice_gt.mp4"],
                        ["./examples/seashore.mp4"],
                        ["./examples/typewriter.mp4"],
                        ["./examples/tennis_gt.mp4"],
                        ["./examples/chew.mp4"],
                    ],
                    inputs = [video_in]
                )
            with gr.Column():
                output_sound = gr.Audio(label="Audio OUT")
                output_spectrogram = gr.Image(label='Spectrogram')
                merged_out = gr.Video(label="Merged video + generated audio")
    # Wire the button to infer(); hidden from the public API surface
    submit_btn.click(
        fn = infer,
        inputs = [video_in],
        outputs = [output_sound, output_spectrogram, merged_out],
        show_api = False
    )
demo.launch(show_api=False, show_error=True)