File size: 8,089 Bytes
b71b6bf
 
 
 
 
8d1d12e
b40da7f
 
 
6c718bc
b40da7f
b71b6bf
466df4f
b71b6bf
7b4a4e7
b71b6bf
 
 
 
 
 
 
 
 
 
 
 
 
 
abd5868
 
 
e932cdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae4e904
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99578a5
 
 
 
 
 
 
 
 
6c718bc
00481fb
 
 
 
 
 
99578a5
 
00481fb
6c718bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99578a5
b71b6bf
 
ae4e904
 
 
b71b6bf
 
c14bad7
 
b71b6bf
c5fdebd
 
 
 
 
 
 
b7a6e89
c5fdebd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4b9351
bec6d41
 
 
 
c5fdebd
 
 
b71b6bf
c5fdebd
 
 
 
e1a2379
c5fdebd
b71b6bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e7972e
e932cdf
 
99578a5
241ee7b
 
b71b6bf
640f8a5
 
b71b6bf
c5f474f
 
 
 
 
 
 
 
 
 
 
640f8a5
 
 
 
 
 
 
 
 
48803b8
 
 
640f8a5
 
 
 
 
 
 
6c718bc
b71b6bf
 
 
6c718bc
b71b6bf
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import gradio as gr
import huggingface_hub
import os 
import subprocess
import threading
import shutil
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile
from moviepy.editor import VideoFileClip, AudioFileClip

# Fetch the VTA-LDM (CLIP4Clip "v-large") checkpoint from the Hugging Face Hub
# into ./ckpt/vta-ldm-clip4clip-v-large when the module is loaded. The
# inference command later reads summary.jsonl and pytorch_model_2.bin from
# this directory.
huggingface_hub.snapshot_download(
    local_dir='./ckpt/vta-ldm-clip4clip-v-large',
    repo_id='ariesssxu/vta-ldm-clip4clip-v-large',
)

def stream_output(pipe):
    """Echo every line read from *pipe* to stdout until EOF.

    Used as a thread target so a child process's stdout and stderr can be
    drained concurrently. *pipe* must be a text-mode stream: end of input is
    detected by ``readline`` returning the empty string.
    """
    while True:
        chunk = pipe.readline()
        if chunk == '':
            break
        # Lines already carry their own newline, so suppress print's.
        print(chunk, end='')

def print_directory_contents(path):
    """Print an indented tree of every directory and file under *path*.

    Each directory is printed as ``name/``, followed by its files one
    indentation level (four spaces) deeper. Depth is computed by removing
    *path* from the walked directory name and counting path separators.
    """
    for current_dir, _subdirs, filenames in os.walk(path):
        depth = current_dir.replace(path, '').count(os.sep)
        pad = '    ' * depth
        print(pad + os.path.basename(current_dir) + '/')
        child_pad = pad + '    '
        for name in filenames:
            print(child_pad + name)

# Print the contents of ./ckpt at import time, right after the checkpoint
# download above, so the downloaded files show up in the logs.
print_directory_contents('./ckpt')

def get_wav_files(path):
    """Walk *path*, print an indented tree of it, and return all ``.wav`` paths.

    In the printed tree, ``.wav`` entries are shown as their full joined
    path, while other files are shown by bare name. The extension match is
    case-insensitive. Paths are returned in ``os.walk`` traversal order.
    """
    found = []
    for current_dir, _subdirs, filenames in os.walk(path):
        depth = current_dir.replace(path, '').count(os.sep)
        pad = '    ' * depth
        print(pad + os.path.basename(current_dir) + '/')
        for name in filenames:
            full_path = os.path.join(current_dir, name)
            is_wav = name.lower().endswith('.wav')
            if is_wav:
                found.append(full_path)
            print(pad + '    ' + (full_path if is_wav else name))
    return found

def check_outputs_folder(folder_path):
    """Empty *folder_path* in place while keeping the folder itself.

    Files and symlinks are unlinked, and sub-directories are removed
    recursively. A failure on one entry is printed and skipped, so the rest
    of the clean-up still runs. If the folder is missing (or is not a
    directory), a message is printed and nothing else happens.
    """
    if not (os.path.exists(folder_path) and os.path.isdir(folder_path)):
        print(f'The folder {folder_path} does not exist.')
        return

    for entry in os.listdir(folder_path):
        target = os.path.join(folder_path, entry)
        try:
            # Links are unlinked rather than followed, so a symlinked
            # directory's target is left intact.
            if os.path.isfile(target) or os.path.islink(target):
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            print(f'Failed to delete {target}. Reason: {err}')

def plot_spectrogram(wav_file, output_image):
    """Render a borderless grayscale spectrogram of *wav_file* to *output_image*.

    Two-dimensional (multi-channel) audio is averaged across its second axis
    to get a single channel before plotting. The figure is always closed
    afterwards, even if plotting or saving fails. Otherwise every call leaves
    an open figure behind in the long-running server process.

    Args:
        wav_file: Path to a WAV file readable by ``scipy.io.wavfile.read``.
        output_image: Destination path passed to ``matplotlib.pyplot.savefig``.
    """
    sample_rate, audio_data = wavfile.read(wav_file)

    # wavfile.read returns a 2-D array for multi-channel files; average the
    # channels down to mono so specgram receives a 1-D signal.
    if audio_data.ndim == 2:
        audio_data = audio_data.mean(axis=1)

    fig = plt.figure(figsize=(10, 2))
    try:
        plt.specgram(audio_data, Fs=sample_rate, NFFT=1024, noverlap=512, cmap='gray', aspect='auto')

        # Remove gridlines and ticks for a cleaner look.
        plt.grid(False)
        plt.xticks([])
        plt.yticks([])

        plt.savefig(output_image, bbox_inches='tight', pad_inches=0, dpi=300)
    finally:
        # The original code referenced `plt.close` without calling it, so the
        # figure was never released.
        plt.close(fig)

def merge_audio_to_video(input_vid, input_aud, output_path="output_video.mp4"):
    """Replace the soundtrack of *input_vid* with *input_aud* and save the result.

    Both clips are closed once the file is written, including when writing
    fails. This releases the ffmpeg readers MoviePy opens for each clip.

    Args:
        input_vid: Path to the source video.
        input_aud: Path to the audio file used as the new soundtrack.
        output_path: Destination file, encoded with libx264 video and AAC
            audio. Defaults to ``"output_video.mp4"`` in the working
            directory, matching the previous fixed behaviour.

    Returns:
        ``output_path``.
    """
    video = VideoFileClip(input_vid)
    try:
        new_audio = AudioFileClip(input_aud)
        try:
            # set_audio returns a new clip; the original `video` is unchanged.
            video_with_new_audio = video.set_audio(new_audio)
            video_with_new_audio.write_videofile(output_path, codec='libx264', audio_codec='aac')
        finally:
            new_audio.close()
    finally:
        video.close()

    return output_path

# Longest input (in seconds) handed to the model; longer uploads are trimmed.
_MAX_DURATION_SECONDS = 10
# Checkpoint directory populated by the snapshot download at import time.
_CKPT_DIR = 'ckpt/vta-ldm-clip4clip-v-large'
# Folder the inference script's results are read from (emptied per request).
_OUTPUTS_DIR = './outputs/tmp'


def _cut_video_path(input_video_path):
    """Return the path where the trimmed copy of *input_video_path* is written.

    The name is built with ``os.path.splitext`` so it always differs from the
    input, whatever the input's extension. A plain ``".mp4"`` substring
    replace would leave e.g. ``.webm`` or ``.MP4`` names unchanged. The copy
    would then overwrite its own source, and the later delete would remove
    the only copy.
    """
    stem, _ext = os.path.splitext(input_video_path)
    return f"{stem}_{_MAX_DURATION_SECONDS}sec_cut.mp4"


def _prepare_input_video(input_video_path):
    """Trim *input_video_path* to at most ``_MAX_DURATION_SECONDS`` if needed.

    If trimming happens, the first ``_MAX_DURATION_SECONDS`` seconds are
    written next to the input as an H.264/AAC ``.mp4``. The original file is
    then deleted, so the folder handed to the inference script holds only the
    trimmed copy.

    Returns:
        ``(video_path, duration_seconds)``: the path to feed to the model and
        its whole-second duration (truncated with ``int``).
    """
    video = VideoFileClip(input_video_path)
    try:
        video_duration = int(video.duration)
        print(f"Video duration: {video_duration} seconds")

        if video_duration <= _MAX_DURATION_SECONDS:
            print("Video is 10 seconds or shorter; no cutting needed.")
            return input_video_path, video_duration

        output_video_path = _cut_video_path(input_video_path)
        cut_video = video.subclip(0, _MAX_DURATION_SECONDS)
        cut_video.write_videofile(output_video_path, codec='libx264', audio_codec='aac')
        print(f"Cut video saved as: {output_video_path}")
    finally:
        # Release the ffmpeg reader before the source file is deleted below.
        video.close()

    # The script receives the directory (--data_path), not a single file, so
    # presumably it processes every video found there. Drop the untrimmed
    # original so only the cut copy remains.
    os.remove(input_video_path)
    print(f"Original video file {input_video_path} deleted.")
    return output_video_path, _MAX_DURATION_SECONDS


def _run_inference(folder_path, video_duration):
    """Run inference_from_video.py on *folder_path*, streaming its output.

    Returns:
        The script's exit code.
    """
    import sys  # Local import: only needed to reuse the running interpreter.

    command = [
        sys.executable, 'inference_from_video.py',
        '--original_args', f'{_CKPT_DIR}/summary.jsonl',
        '--model', f'{_CKPT_DIR}/pytorch_model_2.bin',
        '--data_path', folder_path,
        '--max_duration', f"{video_duration}",
    ]
    with subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                          text=True, bufsize=1) as process:
        # Drain both pipes on separate threads so neither can fill up and
        # block the child process.
        readers = [
            threading.Thread(target=stream_output, args=(pipe,))
            for pipe in (process.stdout, process.stderr)
        ]
        for reader in readers:
            reader.start()
        process.wait()
        for reader in readers:
            reader.join()

    print("Inference script finished with return code:", process.returncode)
    return process.returncode


def infer(video_in):
    """Generate audio for an uploaded video (Gradio click handler).

    Steps:
        1. Empty the outputs folder.
        2. Trim the video to ``_MAX_DURATION_SECONDS`` if it is longer.
        3. Run the inference script on the video's folder.
        4. Plot a spectrogram of the first generated WAV.
        5. Mux that WAV onto the (possibly trimmed) video.

    Args:
        video_in: Filesystem path of the uploaded video, as given by
            ``gr.Video``. It is ``None``/empty when nothing was uploaded.

    Returns:
        ``(wav_path, "spectrogram.png", merged_video_path)``.

    Raises:
        gr.Error: If no video was provided, or if the inference script
            produced no WAV file.
    """
    if not video_in:
        raise gr.Error("Please provide an input video.")

    check_outputs_folder(_OUTPUTS_DIR)

    print(f"VIDEO IN PATH: {video_in}")
    # Compute the folder before trimming: the cut copy is written to the same
    # directory, which is what the script is pointed at.
    folder_path = os.path.dirname(video_in)

    video_in, video_duration = _prepare_input_video(video_in)
    return_code = _run_inference(folder_path, video_duration)

    print_directory_contents(_OUTPUTS_DIR)
    wave_files = get_wav_files(_OUTPUTS_DIR)
    print(wave_files)
    if not wave_files:
        # Previously this surfaced as an opaque IndexError on wave_files[0].
        raise gr.Error(
            f"Inference failed: no audio was generated (return code {return_code})."
        )

    plot_spectrogram(wave_files[0], 'spectrogram.png')
    final_merged_out = merge_audio_to_video(video_in, wave_files[0])
    return wave_files[0], 'spectrogram.png', final_merged_out


# Two-column UI: the input video, submit button and bundled examples on the
# left; the generated audio, its spectrogram and the re-muxed video on the right.
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Video-to-Audio Generation with Hidden Alignment")
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href='https://sites.google.com/view/vta-ldm'>
                <img src='https://img.shields.io/badge/Project-Page-Green'>
            </a>
            <a href='https://huggingface.co/papers/2407.07464'>
                <img src='https://img.shields.io/badge/HF-Paper-red'>
            </a>
        </div>
        """)
        with gr.Row():
            with gr.Column():
                video_in = gr.Video(label='Video IN')
                submit_btn = gr.Button("Submit")
                gr.Examples(
                    examples=[
                        [example_path]
                        for example_path in (
                            "./examples/lion_gt.mp4",
                            "./examples/ice_gt.mp4",
                            "./examples/seashore.mp4",
                            "./examples/typewriter.mp4",
                            "./examples/tennis_gt.mp4",
                            "./examples/chew.mp4",
                        )
                    ],
                    inputs=[video_in],
                )
            with gr.Column():
                output_sound = gr.Audio(label="Audio OUT")
                output_spectrogram = gr.Image(label='Spectrogram')
                merged_out = gr.Video(label="Merged video + generated audio")

    # infer returns (wav path, spectrogram image, merged video), which map
    # positionally onto the three output components.
    submit_btn.click(
        fn=infer,
        inputs=[video_in],
        outputs=[output_sound, output_spectrogram, merged_out],
        show_api=False,
    )
demo.launch(show_api=False, show_error=True)