# Hacked together using the code from https://github.com/nikhilsinghmus/image2reverb

import os, types
import numpy as np
import gradio as gr
import soundfile as sf
import scipy.signal
import librosa.display
from PIL import Image

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from pytorch_lightning import Trainer

from image2reverb.model import Image2Reverb
from image2reverb.stft import STFT


# Most recent prediction results, filled in by the patched Lightning test hooks below
# and read back out in predict().
predicted_ir = None
predicted_spectrogram = None
predicted_depthmap = None


def test_step(self, batch, batch_idx):
    spec, label, paths = batch
    examples = [os.path.splitext(os.path.basename(s))[0] for _, s in zip(*paths)]

    # The encoder returns image features and the depth-augmented input image.
    f, img = self.enc.forward(label)

    # Pad the encoder features with random noise up to the model's latent dimension.
    shape = (
        f.shape[0],
        (self._latent_dimension - f.shape[1]) if f.shape[1] < self._latent_dimension else f.shape[1],
        f.shape[2],
        f.shape[3]
    )
    z = torch.cat((f, torch.randn(shape, device=self.device)), 1)

    fake_spec = self.g(z)

    # Invert the generated spectrograms back to time-domain impulse responses.
    stft = STFT()
    y_f = [stft.inverse(s.squeeze()) for s in fake_spec]

    # TODO: a bit hacky; stash the results in module-level globals for predict()
    global predicted_ir, predicted_spectrogram, predicted_depthmap
    predicted_ir = y_f[0]

    s = fake_spec.squeeze().cpu().numpy()
    # Undo the [-1, 1] scaling to recover a linear-magnitude spectrogram for display
    # (the constants presumably mirror the log-magnitude normalization used in training).
    predicted_spectrogram = np.exp((((s + 1) * 0.5) * 19.5) - 17.5) - 1e-8

    # The last channel of the (depth-augmented) encoder image is the predicted depthmap.
    img = (img + 1) * 0.5
    predicted_depthmap = img.cpu().squeeze().permute(1, 2, 0)[:,:,-1].squeeze().numpy()

    return {"test_audio": y_f, "test_examples": examples}


def test_epoch_end(self, outputs):
    if not self.test_callback:
        return

    examples = []
    audio = []

    for output in outputs:
        for i in range(len(output["test_examples"])):
            audio.append(output["test_audio"][i])
            examples.append(output["test_examples"][i])

    self.test_callback(examples, audio)


checkpoint_path = "./checkpoints/image2reverb_f22.ckpt"
encoder_path = None
depthmodel_path = "./checkpoints/mono_odom_640x192"
constant_depth = None
latent_dimension = 512

model = Image2Reverb(encoder_path, depthmodel_path)
m = torch.load(checkpoint_path, map_location=model.device)
model.load_state_dict(m["state_dict"])

model.test_step = types.MethodType(test_step, model)
model.test_epoch_end = types.MethodType(test_epoch_end, model)

image_transforms = transforms.Compose([
    transforms.Resize([224, 224], transforms.functional.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


class Image2ReverbDemoDataset(Dataset):
    def __init__(self, image):
        self.image = Image.fromarray(image)
        self.stft = STFT()

    def __getitem__(self, index):
        img_tensor = image_transforms(self.image.convert("RGB"))
        # The spectrogram slot is just a placeholder (roughly 5.94 s at 22050 Hz);
        # only the image is used at test time.
        return torch.zeros(1, int(5.94 * 22050)), img_tensor, ("", "")

    def __len__(self):
        return 1

    def name(self):
        return "Image2ReverbDemo"


def convolve(audio, reverb):
    # Convolve the dry audio with the impulse response; zero-pad the tail so the reverb can ring out.
    wet_audio = np.concatenate((audio, np.zeros(reverb.shape)))
    wet_audio = scipy.signal.oaconvolve(wet_audio, reverb, "full")[:len(wet_audio)]

    # normalize audio to roughly -1 dB peak and remove DC offset
    wet_audio /= np.max(np.abs(wet_audio))
    wet_audio -= np.mean(wet_audio)
    wet_audio *= 0.9
    return wet_audio
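# Note: convolve() returns len(audio) + len(reverb) samples of fully wet signal with a
# peak of roughly 0.9 (about -1 dBFS); the 0.9 factor in predict()'s dry/wet mix matches this level.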


def predict(image, audio):
    # image: numpy array of shape (height, width, channels)
    # audio: tuple (sample_rate, data), where data is int16 with shape (frames,) or (frames, channels)

    test_set = Image2ReverbDemoDataset(image)
    test_loader = torch.utils.data.DataLoader(test_set, num_workers=0, batch_size=1)
    trainer = Trainer(limit_test_batches=1)
    trainer.test(model, test_loader, verbose=True)

    # depthmap output
    depthmap_fig = plt.figure()
    plt.imshow(predicted_depthmap)
    plt.close()

    # spectrogram output
    spectrogram_fig = plt.figure()
    librosa.display.specshow(predicted_spectrogram, sr=22050, x_axis="time", y_axis="hz")
    plt.close()

    # plot the IR as a waveform
    waveform_fig = plt.figure()
    librosa.display.waveshow(predicted_ir, sr=22050, alpha=0.5)
    plt.close()

    # output audio as 16-bit signed integer
    ir = (22050, (predicted_ir * 32767).astype(np.int16))

    sample_rate, original_audio = audio

    # incoming audio is 16-bit signed integer, convert to float and normalize
    original_audio = original_audio.astype(np.float32) / 32768.0
    original_audio /= np.max(np.abs(original_audio))

    # resample reverb to sample_rate first, also normalize
    reverb = predicted_ir.copy()
    reverb = scipy.signal.resample_poly(reverb, up=sample_rate, down=22050)
    reverb /= np.max(np.abs(reverb))

    # stereo?
    if len(original_audio.shape) > 1:
        wet_left = convolve(original_audio[:, 0], reverb)
        wet_right = convolve(original_audio[:, 1], reverb)
        wet_audio = np.concatenate([wet_left[:, None], wet_right[:, None]], axis=1)
    else:
        wet_audio = convolve(original_audio, reverb)

    # 50% dry / 50% wet mix; the 0.9 factor matches the peak level applied in convolve()
    mixed_audio = wet_audio * 0.5
    mixed_audio[:len(original_audio), ...] += original_audio * 0.9 * 0.5

    # convert back to 16-bit signed integer
    wet_audio = (wet_audio * 32767).astype(np.int16)
    mixed_audio = (mixed_audio * 32767).astype(np.int16)

    convolved_audio_100 = (sample_rate, wet_audio)
    convolved_audio_50 = (sample_rate, mixed_audio)

    return depthmap_fig, spectrogram_fig, waveform_fig, ir, convolved_audio_100, convolved_audio_50
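# predict() can also be called outside of Gradio; a rough sketch using the bundled example files:
#
#     data, sr = sf.read("examples/ashesanddreams.wav", dtype="int16")
#     image = np.array(Image.open("examples/input.4e2f71f6.png"))
#     outputs = predict(image, (sr, data))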


title = "Image2Reverb: Cross-Modal Reverb Impulse Response Synthesis"

description = """
<b>Image2Reverb</b> predicts the acoustic reverberation of a given environment from a 2D image. <a href="https://arxiv.org/abs/2103.14201">Read the paper</a>

How to use: Choose an image of a room or other environment and an audio file.
The model predicts what the reverb of the room sounds like and applies it to the audio file.

First, the image is resized to 224×224. The monodepth model is used to predict a depthmap, which is added as an
additional channel to the image input. A ResNet-based encoder then converts the image into features, and
finally a GAN predicts the spectrogram of the reverb's impulse response.

<center><img src="file/model.jpg" width="870" height="297" alt="model architecture"></center>

The predicted impulse response is mono at 22050 Hz. It is resampled to the sampling rate of the audio
file and applied to both channels if the audio is stereo.
Generating the impulse response involves a certain amount of randomness, making it sound a little
different every time you try it.
"""

article = """
<div style='margin:20px auto;'>

<p>Based on original work by Nikhil Singh, Jeff Mentch, Jerry Ng, Matthew Beveridge, Iddo Drori.
<a href="https://web.media.mit.edu/~nsingh1/image2reverb/">Project Page</a> |
<a href="https://arxiv.org/abs/2103.14201">Paper</a> |
<a href="https://github.com/nikhilsinghmus/image2reverb">GitHub</a></p>

<pre>
@InProceedings{Singh_2021_ICCV,
    author    = {Singh, Nikhil and Mentch, Jeff and Ng, Jerry and Beveridge, Matthew and Drori, Iddo},
    title     = {Image2Reverb: Cross-Modal Reverb Impulse Response Synthesis},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {286-295}
}
</pre>

<p>🌠 Example images from <a href="https://web.media.mit.edu/~nsingh1/image2reverb/">the original project page</a>.</p>

<p>🎶 Example sound from <a href="https://freesound.org/people/ashesanddreams/sounds/610414/">Ashes and Dreams @ freesound.org</a> (CC BY 4.0 license). This is a mono 48 kHz recording that has no reverb on it.</p>

</div>
"""

audio_example = "examples/ashesanddreams.wav"

examples = [
    ["examples/input.4e2f71f6.png", audio_example],
    ["examples/input.321eef38.png", audio_example],
    ["examples/input.2238dc21.png", audio_example],
    ["examples/input.4d280b40.png", audio_example],
    ["examples/input.0c3f5013.png", audio_example],
    ["examples/input.98773b90.png", audio_example],
    ["examples/input.ac61500f.png", audio_example],
    ["examples/input.5416407f.png", audio_example],
]

gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Upload Image"),
        gr.Audio(label="Upload Audio", source="upload"),
    ],
    outputs=[
        gr.Plot(label="Depthmap"),
        gr.Plot(label="Impulse Response Spectrogram"),
        gr.Plot(label="Impulse Response Waveform"),
        gr.outputs.Audio(label="Impulse Response"),
        gr.outputs.Audio(label="Output Audio (100% Wet)"),
        gr.outputs.Audio(label="Output Audio (50% Dry, 50% Wet)"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()