bpiyush commited on
Commit
eafbf97
1 Parent(s): b6e4b21

Upload folder using huggingface_hub

Browse files
shared/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ This folder shall have code utilities shared across different tasks.
shared/scripts/upload_data.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Uploads dataset to huggingface datasets."""
2
+ import os
3
+ import sys
4
+
5
+ import pandas as pd
6
+ import numpy as np
7
+
8
+ from huggingface_hub import HfApi
9
+ import shared.utils as su
10
+ from sound_of_water.data.csv_loader import (
11
+ load_csv_sound_of_water,
12
+ configure_paths_sound_of_water,
13
+ )
14
+
15
+
16
+ if __name__ == "__main__":
17
+ api = HfApi()
18
+
19
+ data_root = "/work/piyush/from_nfs2/datasets/SoundOfWater"
20
+ repo_id = "bpiyush/sound-of-water"
21
+
22
+ save_splits = False
23
+ if save_splits:
24
+ # Load CSV
25
+ paths = configure_paths_sound_of_water(data_root)
26
+ df = load_csv_sound_of_water(paths)
27
+ del df["video_clip_path"]
28
+ del df["audio_clip_path"]
29
+ del df["box_path"]
30
+ del df["mask_path"]
31
+
32
+ # Splits
33
+ train_ids = su.io.load_txt(os.path.join(data_root, "splits/train.txt"))
34
+ df_train = df[df.item_id.isin(train_ids)]
35
+ df_train["file_name"] = df_train["item_id"].apply(lambda x: f"videos/{x}.mp4")
36
+ df_train.to_csv(os.path.join(data_root, "splits/train.csv"), index=False)
37
+ print(" [:::] Train split saved.")
38
+
39
+ test_I_ids = su.io.load_txt(os.path.join(data_root, "splits/test_I.txt"))
40
+ df_test_I = df[df.item_id.isin(test_I_ids)]
41
+ df_test_I["file_name"] = df_test_I["item_id"].apply(lambda x: f"videos/{x}.mp4")
42
+ df_test_I.to_csv(os.path.join(data_root, "splits/test_I.csv"), index=False)
43
+ print(" [:::] Test I split saved.")
44
+
45
+ test_II_ids = su.io.load_txt(os.path.join(data_root, "splits/test_II.txt"))
46
+ df_test_II = df[df.item_id.isin(test_II_ids)]
47
+ df_test_II["file_name"] = df_test_II["item_id"].apply(lambda x: f"videos/{x}.mp4")
48
+ df_test_II.to_csv(os.path.join(data_root, "splits/test_II.csv"), index=False)
49
+ print(" [:::] Test II split saved.")
50
+
51
+ test_III_ids = su.io.load_txt(os.path.join(data_root, "splits/test_III.txt"))
52
+ df_test_III = df[df.item_id.isin(test_III_ids)]
53
+ df_test_III["file_name"] = df_test_III["item_id"].apply(lambda x: f"videos/{x}.mp4")
54
+ df_test_III.to_csv(os.path.join(data_root, "splits/test_III.csv"), index=False)
55
+ print(" [:::] Test III split saved.")
56
+
57
+
58
+ create_splits = False
59
+ if create_splits:
60
+ train_ids = su.io.load_txt(os.path.join(data_root, "splits/train.txt"))
61
+ train_ids = np.unique(train_ids)
62
+
63
+ test_I_ids = su.io.load_txt(os.path.join(data_root, "splits/test_I.txt"))
64
+ test_I_ids = np.unique(test_I_ids)
65
+
66
+ other_ids = np.array(
67
+ list(set(df.item_id.unique()) - set(train_ids) - set(test_I_ids))
68
+ )
69
+ sub_df = df[~df.item_id.isin(set(train_ids) | set(test_I_ids))]
70
+ X = sub_df[
71
+ (sub_df.visibility != "transparent") & (sub_df["shape"].isin(["cylindrical", "semiconical"]))
72
+ ]
73
+ test_II_ids = list(X.item_id.unique())
74
+ assert set(test_II_ids).intersection(set(train_ids)) == set()
75
+ assert set(test_II_ids).intersection(set(test_I_ids)) == set()
76
+ su.io.save_txt(test_II_ids, os.path.join(data_root, "splits/test_II.txt"))
77
+
78
+ X = sub_df[
79
+ (sub_df.visibility.isin(["transparent", "opaque"])) & \
80
+ (sub_df["shape"].isin(["cylindrical", "semiconical", "bottleneck"]))
81
+ ]
82
+ test_III_ids = list(X.item_id.unique())
83
+ assert set(test_III_ids).intersection(set(train_ids)) == set()
84
+ assert set(test_III_ids).intersection(set(test_I_ids)) == set()
85
+ assert set(test_III_ids).intersection(set(test_II_ids)) != set()
86
+ su.io.save_txt(test_III_ids, os.path.join(data_root, "splits/test_III.txt"))
87
+
88
+ upload_file = True
89
+ if upload_file:
90
+ file = "README.md"
91
+ print(f" [:::] Uploading file: {file}")
92
+ api.upload_file(
93
+ path_or_fileobj=os.path.join(data_root, file),
94
+ path_in_repo=file,
95
+ repo_id=repo_id,
96
+ repo_type="dataset",
97
+ )
98
+
99
+ upload_folder = False
100
+ if upload_folder:
101
+ # Upload splits folder
102
+ foldername = "annotations"
103
+ print(f" [:::] Uploading folder: {foldername}")
104
+ api.upload_folder(
105
+ folder_path=os.path.join(data_root, foldername),
106
+ path_in_repo=foldername, # Upload to a specific folder
107
+ repo_id=repo_id,
108
+ repo_type="dataset",
109
+ )
shared/utils/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shared.utils.paths as paths
2
+ import shared.utils.log as log
3
+ import shared.utils.io as io
4
+ import shared.utils.audio as audio
5
+ import shared.utils.image as image
6
+ import shared.utils.av as av
7
+ import shared.utils.pandas_utils as pd_utils
8
+ import shared.utils.visualize as visualize
9
+ import shared.utils.metrics as metrics
10
+ import shared.utils.misc as misc
11
+ import shared.utils.keypoint_matching as keypoint_matching
12
+ import shared.utils.physics as physics
shared/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (768 Bytes). View file
 
shared/utils/__pycache__/audio.cpython-39.pyc ADDED
Binary file (6.01 kB). View file
 
shared/utils/__pycache__/av.cpython-39.pyc ADDED
Binary file (2.72 kB). View file
 
shared/utils/__pycache__/image.cpython-39.pyc ADDED
Binary file (1.96 kB). View file
 
shared/utils/__pycache__/io.cpython-39.pyc ADDED
Binary file (4.76 kB). View file
 
shared/utils/__pycache__/keypoint_matching.cpython-39.pyc ADDED
Binary file (10 kB). View file
 
shared/utils/__pycache__/log.cpython-39.pyc ADDED
Binary file (2.15 kB). View file
 
shared/utils/__pycache__/metrics.cpython-39.pyc ADDED
Binary file (10.6 kB). View file
 
shared/utils/__pycache__/misc.cpython-39.pyc ADDED
Binary file (4.28 kB). View file
 
shared/utils/__pycache__/pandas_utils.cpython-39.pyc ADDED
Binary file (3.2 kB). View file
 
shared/utils/__pycache__/paths.cpython-39.pyc ADDED
Binary file (918 Bytes). View file
 
shared/utils/__pycache__/physics.cpython-39.pyc ADDED
Binary file (7.17 kB). View file
 
shared/utils/__pycache__/visualize.cpython-39.pyc ADDED
Binary file (54.5 kB). View file
 
shared/utils/audio.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio utils"""
2
+ import librosa
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+
6
+
7
+ def load_audio(audio_path: str, sr: int = None, max_duration: int = 10., start: int = 0, stop: int = None):
8
+ """Loads audio and pads/trims it to max_duration"""
9
+ data, sr = librosa.load(audio_path, sr=sr)
10
+
11
+ if stop is not None:
12
+ start = int(start * sr)
13
+ stop = int(stop * sr)
14
+ data = data[start:stop]
15
+
16
+ # Convert to mono
17
+ if len(data.shape) > 1:
18
+ data = np.mean(data, axis=1)
19
+
20
+ n_frames = int(max_duration * sr)
21
+ if len(data) > n_frames:
22
+ data = data[:n_frames]
23
+ elif len(data) < n_frames:
24
+ data = np.pad(data, (0, n_frames - len(data)), "constant")
25
+ return data, sr
26
+
27
+
28
+ # def compute_spectrogram(data: np.ndarray, sr: int):
29
+ # D = librosa.stft(data) # STFT of y
30
+ # S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
31
+ # return S_db
32
+
33
+
34
+ def compute_spec_freq_mean(S_db: np.ndarray, eps=1e-5):
35
+ # Compute mean of spectrogram over frequency axis
36
+ S_db_normalized = (S_db - S_db.mean(axis=1)[:, None]) / (S_db.std(axis=1)[:, None] + eps)
37
+ S_db_over_time = S_db_normalized.sum(axis=0)
38
+ return S_db_over_time
39
+
40
+
41
+ def process_audiofile(audio_path, functions=["load_audio", "compute_spectrogram", "compute_spec_freq_mean"]):
42
+ """Processes audio file with a list of functions"""
43
+ data, sr = load_audio(audio_path)
44
+ for function in functions:
45
+ if function == "load_audio":
46
+ pass
47
+ elif function == "compute_spectrogram":
48
+ data = compute_spectrogram(data, sr)
49
+ elif function == "compute_spec_freq_mean":
50
+ data = compute_spec_freq_mean(data)
51
+ else:
52
+ raise ValueError(f"Unknown function {function}")
53
+ return data
54
+
55
+
56
+
57
+ """PyDub's silence detection is based on the energy of the audio signal."""
58
+ import numpy as np
59
+
60
+
61
+ def sigmoid(x):
62
+ return 1 / (1 + np.exp(-x))
63
+
64
+
65
+ class SilenceDetector:
66
+
67
+
68
+ def __init__(self, silence_thresh=-36) -> None:
69
+ self.silence_thresh = silence_thresh
70
+
71
+ def __call__(self, audio_path: str, start=None, end=None):
72
+
73
+ import pydub
74
+ from pydub.utils import db_to_float
75
+
76
+ try:
77
+ waveform = pydub.AudioSegment.from_file(audio_path)
78
+ except:
79
+ print("Error loading audio file: ", audio_path)
80
+ return 100.0
81
+
82
+ start_ms = int(start * 1000) if start else 0
83
+ end_ms = int(end * 1000) if end else len(waveform)
84
+ waveform = waveform[start_ms:end_ms]
85
+
86
+ # convert silence threshold to a float value (so we can compare it to rms)
87
+ silence_thresh = db_to_float(self.silence_thresh) * waveform.max_possible_amplitude
88
+
89
+ if waveform.rms == 0:
90
+ return 100.0
91
+
92
+ silence_prob = sigmoid((silence_thresh - waveform.rms) / waveform.rms)
93
+
94
+ # return waveform.rms <= silence_thresh
95
+ return np.round(100 * silence_prob, 2)
96
+
97
+
98
+ def frequency_bin_to_value(bin_index, sr, n_fft):
99
+ return int(bin_index * sr / n_fft)
100
+
101
+
102
+ def time_bin_to_value(bin_index, hop_length, sr):
103
+ return (bin_index) * (hop_length / sr)
104
+
105
+
106
+ def add_time_annotations(ax, nt_bins, hop_length, sr, skip=50):
107
+ # Show time (s) values on the x-axis
108
+ t_bins = np.arange(nt_bins)
109
+ t_vals = np.round(np.array([time_bin_to_value(tb, hop_length, sr) for tb in t_bins]), 1)
110
+ try:
111
+ ax.set_xticks(t_bins[::skip], t_vals[::skip])
112
+ except:
113
+ pass
114
+ ax.set_xlabel("Time (s)")
115
+
116
+
117
+ def add_freq_annotations(ax, nf_bins, sr, n_fft, skip=50):
118
+ f_bins = np.arange(nf_bins)
119
+ f_vals = np.array([frequency_bin_to_value(fb, sr, n_fft) for fb in f_bins])
120
+ try:
121
+ ax.set_yticks(f_bins[::skip], f_vals[::skip])
122
+ except:
123
+ pass
124
+ # ax.set_yticks(f_bins[::skip])
125
+ # ax.set_yticklabels(f_vals[::skip])
126
+ ax.set_ylabel("Frequency (Hz)")
127
+
128
+
129
+ def show_single_spectrogram(
130
+ spec,
131
+ sr,
132
+ n_fft,
133
+ hop_length,
134
+ ax=None,
135
+ fig=None,
136
+ figsize=(10, 2),
137
+ cmap="viridis",
138
+ colorbar=True,
139
+ show=True,
140
+ format='%+2.0f dB',
141
+ xlabel='Time (s)',
142
+ ylabel="Frequency (Hz)",
143
+ title=None,
144
+ show_dom_freq=False,
145
+ ):
146
+
147
+ if ax is None:
148
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
149
+ axim = ax.imshow(spec, origin="lower", cmap=cmap)
150
+
151
+ # Show frequency (Hz) values on y-axis
152
+ nf_bins, nt_bins = spec.shape
153
+
154
+ if "frequency" in ylabel.lower():
155
+ # Add frequency annotation
156
+ add_freq_annotations(ax, nf_bins, sr, n_fft)
157
+
158
+ # Add time annotation
159
+ add_time_annotations(ax, nt_bins, hop_length, sr)
160
+
161
+ ax.set_title(title)
162
+ ax.set_xlabel(xlabel)
163
+ ax.set_ylabel(ylabel)
164
+
165
+ if colorbar:
166
+ fig.colorbar(axim, ax=ax, orientation='vertical', fraction=0.01, format=format)
167
+
168
+ if show_dom_freq:
169
+ fmax = spec.argmax(axis=0)
170
+ ax.scatter(np.arange(spec.shape[1]), fmax, color="white", s=0.2)
171
+
172
+ if show:
173
+ plt.show()
174
+
175
+
176
+ def compute_spectrogram(y, n_fft, hop_length, margin, n_mels=None):
177
+ # STFT
178
+ D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
179
+
180
+ # Run HPSS
181
+ S, _ = librosa.decompose.hpss(D, margin=margin)
182
+
183
+ # DB
184
+ S = librosa.amplitude_to_db(np.abs(S), ref=np.max)
185
+
186
+ if n_mels is not None:
187
+ S = librosa.feature.melspectrogram(S=S, n_mels=n_mels)
188
+
189
+ return S
190
+
191
+
192
+ def show_spectrogram(S, sr, n_fft=512, hop_length=256, figsize=(10, 3), n_mels=None, ax=None, show=True):
193
+ if ax is None:
194
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
195
+ y_axis = "mel" if n_mels is not None else "linear"
196
+ librosa.display.specshow(
197
+ S,
198
+ sr=sr,
199
+ hop_length=hop_length,
200
+ n_fft=n_fft,
201
+ y_axis=y_axis,
202
+ x_axis='time',
203
+ ax=ax,
204
+ )
205
+ ax.set_title("LogSpectrogram" if n_mels is None else "LogMelSpectrogram")
206
+ if show:
207
+ plt.show()
208
+
209
+
210
+ def show_frame_and_spectrogram(frame, S, sr, figsize=(12, 4), show=True, axes=None, **spec_args):
211
+ if axes is None:
212
+ fig, axes = plt.subplots(1, 2, figsize=figsize, gridspec_kw={"width_ratios": [0.2, 0.8]})
213
+ ax = axes[0]
214
+ ax.imshow(frame)
215
+ ax.set_xticks([])
216
+ ax.set_yticks([])
217
+
218
+ ax = axes[1]
219
+ show_spectrogram(S=S, sr=sr, ax=ax, show=False, **spec_args)
220
+
221
+ plt.tight_layout()
222
+
223
+ if show:
224
+ plt.show()
shared/utils/av.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio-visual helper functions."""
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ def save_video_with_audio(video, audio, output_path):
8
+ """
9
+ Saves a video file with audio.
10
+
11
+ Args:
12
+ video (np.ndarray): Video frames.
13
+ audio (np.ndarray): Audio samples.
14
+ output_path (str): Output path.
15
+ """
16
+
17
+ # check the correct shape and format for audio
18
+ assert isinstance(audio, np.ndarray)
19
+ assert len(audio.shape) == 2
20
+ assert audio.shape[1] in [1, 2]
21
+
22
+ # create video writer
23
+ video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 30, (video.shape[2], video.shape[1]))
24
+ # write the image frames to the video
25
+ for frame in video:
26
+ video_writer.write(frame)
27
+ # add the audio data to the video
28
+ video_writer.write(audio)
29
+ # release the VideoWriter object
30
+ video_writer.release()
31
+
32
+
33
+ def save_video_from_image_sequence_and_audio(sequence, audio, save_path, video_fps=15, audio_fps=22100):
34
+ from moviepy.editor import VideoClip, AudioClip, ImageSequenceClip
35
+ from moviepy.audio.AudioClip import AudioArrayClip
36
+
37
+ assert isinstance(sequence, list) and isinstance(audio, (np.ndarray, torch.Tensor))
38
+ assert len(audio.shape) == 2 and audio.shape[1] in [1, 2]
39
+
40
+ video_duration = len(sequence) / video_fps
41
+ audio_duration = len(audio) / audio_fps
42
+ # # print(f"Video duration: {video_duration:.2f}s, audio duration: {audio_duration:.2f}s")
43
+ # assert video_duration == audio_duration, \
44
+ # f"Video duration ({video_duration}) and audio duration ({audio_duration}) do not match."
45
+
46
+ video_clip = ImageSequenceClip(sequence, fps=video_fps)
47
+ audio_clip = AudioArrayClip(audio, fps=audio_fps)
48
+ video_clip = video_clip.set_audio(audio_clip)
49
+ # video_clip.write_videofile(save_path, verbose=True, logger=None, fps=video_fps, audio_fps=audio_fps)
50
+ video_clip.write_videofile(save_path, verbose=False, logger=None)
51
+
52
+
53
+ import cv2, os
54
+ import argparse
55
+ import numpy as np
56
+ from glob import glob
57
+ import librosa
58
+ import subprocess
59
+
60
+
61
+ def generate_video(args):
62
+
63
+ frames = glob('{}/*.png'.format(args.input_dir))
64
+ print("Total frames = ", len(frames))
65
+
66
+ frames.sort(key = lambda x: int(x.split("/")[-1].split(".")[0]))
67
+
68
+ img = cv2.imread(frames[0])
69
+ print(img.shape)
70
+ fname = 'inference.avi'
71
+ video = cv2.VideoWriter(
72
+ fname, cv2.VideoWriter_fourcc(*'DIVX'), args.fps, (img.shape[1], img.shape[0]),
73
+ )
74
+
75
+ for i in range(len(frames)):
76
+ img = cv2.imread(frames[i])
77
+ video.write(img)
78
+
79
+ video.release()
80
+
81
+ output_file_name = args.output_video
82
+
83
+ no_sound_video = output_file_name + '_nosound.mp4'
84
+ subprocess.call('ffmpeg -hide_banner -loglevel panic -i %s -c copy -an -strict -2 %s' % (fname, no_sound_video), shell=True)
85
+
86
+ if args.audio_file is not None:
87
+ video_output = output_file_name + '.mp4'
88
+ subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -strict -2 -q:v 1 %s' %
89
+ (args.audio_file, no_sound_video, video_output), shell=True)
90
+
91
+ os.remove(no_sound_video)
92
+
93
+ os.remove(fname)
shared/utils/classification.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Helper functions for classification tasks."""
2
+ import matplotlib.pyplot as plt
3
+ import pandas as pd
4
+ import numpy as np
5
+
6
+
7
+ def plot_metric_curve(
8
+ xvalues, yvalues, thresholds, title=None,
9
+ figsize=(8, 7), show_thresholds=True, show_legend=True,
10
+ ylabel='X', xlabel='Y', ax=None, text_delta=0.01,
11
+ label="Metric Curve", color="royalblue", show=False,
12
+ fill=None,
13
+ ):
14
+ """Plot a metric curve, e.g., PR curve or ROC curve."""
15
+
16
+ if ax is None:
17
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
18
+
19
+ ax.grid(alpha=0.3)
20
+ ax.set_title(title)
21
+ ax.set_ylabel(ylabel)
22
+ ax.set_xlabel(xlabel)
23
+
24
+ ax.plot(xvalues, yvalues, marker='o', label=label, color=color)
25
+ ax.set_xlim(-0.08, 1.08)
26
+ ax.set_ylim(-0.08, 1.08)
27
+
28
+ if fill is not None:
29
+ yticks = ax.get_yticks()
30
+ ax.fill_between(xvalues, yvalues, "", alpha=0.08, color=color)
31
+ # Add `fill` inside the curve
32
+ # Find a single (x, y) s.t. it is inside the curve
33
+ ax.text(0.4, 0.5, fill, color=color)
34
+ ax.set_yticks(yticks)
35
+ ax.set_yticklabels([f"{y:.1f}" for y in yticks])
36
+ ax.set_ylim(-0.08, 1.08)
37
+
38
+ # Show thresholds
39
+ if show_thresholds:
40
+ for x, y, t in zip(xvalues, yvalues, thresholds):
41
+ ax.text(x + text_delta, y + text_delta, np.round(t, 2), color=color, alpha=0.5)
42
+
43
+ if show_legend:
44
+ ax.legend()
45
+
46
+ if show:
47
+ plt.show()
shared/utils/epic.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utils specific for EPIC data."""
2
+ import datetime
3
+
4
+
5
+ def timestamp_to_seconds(timestamp: str):
6
+ # Parse the timestamp string into a datetime object
7
+ time_obj = datetime.datetime.strptime(timestamp, '%H:%M:%S.%f')
8
+
9
+ # Calculate the total number of seconds using the timedelta object
10
+ total_seconds = time_obj.time().second \
11
+ + time_obj.time().minute * 60 \
12
+ + time_obj.time().hour * 3600 \
13
+ + time_obj.time().microsecond / 1000000
14
+
15
+ return total_seconds
shared/utils/image.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Image operations."""
2
+ from copy import deepcopy
3
+ from PIL import Image
4
+ import numpy as np
5
+
6
+
7
+ def center_crop(im: Image):
8
+ width, height = im.size
9
+ new_width = width if width < height else height
10
+ new_height = height if height < width else width
11
+
12
+ left = (width - new_width)/2
13
+ top = (height - new_height)/2
14
+ right = (width + new_width)/2
15
+ bottom = (height + new_height)/2
16
+
17
+ # Crop the center of the image
18
+ im = im.crop((left, top, right, bottom))
19
+
20
+ return im
21
+
22
+
23
+ def pad_to_square(im: Image, color=(0, 0, 0)):
24
+ im = deepcopy(im)
25
+ width, height = im.size
26
+
27
+ vert_pad = (max(width, height) - height) // 2
28
+ hor_pad = (max(width, height) - width) // 2
29
+
30
+ if len(im.mode) == 3:
31
+ color = (0, 0, 0)
32
+ elif len(im.mode) == 1:
33
+ color = 0
34
+ else:
35
+ raise ValueError(f"Image mode not supported. Image has {im.mode} channels.")
36
+
37
+ return add_margin(im, vert_pad, hor_pad, vert_pad, hor_pad, color=color)
38
+
39
+
40
+ def add_margin(pil_img, top, right, bottom, left, color=(0, 0, 0)):
41
+ """Ref: https://note.nkmk.me/en/python-pillow-add-margin-expand-canvas/"""
42
+ width, height = pil_img.size
43
+ new_width = width + right + left
44
+ new_height = height + top + bottom
45
+ result = Image.new(pil_img.mode, (new_width, new_height), color)
46
+ result.paste(pil_img, (left, top))
47
+ return result
48
+
49
+
50
+ def resize_image(image, new_height, new_width):
51
+ # Convert the numpy array image to PIL Image
52
+ pil_image = Image.fromarray(image)
53
+
54
+ # Resize the PIL Image
55
+ resized_image = pil_image.resize((new_width, new_height))
56
+
57
+ # Convert the resized PIL Image back to numpy array
58
+ resized_image_np = np.array(resized_image)
59
+
60
+ return resized_image_np
61
+
62
+
63
+ def pad_to_width(pil_image, new_width, color=(0, 0, 0)):
64
+ """Pad the image to the specified width."""
65
+ # Convert the numpy array image to PIL Image
66
+ # pil_image = Image.fromarray(image)
67
+
68
+ # Get the current width and height of the image
69
+ width, height = pil_image.size
70
+ assert new_width > width, f"New width {new_width} is less than the current width {width}."
71
+
72
+ # Calculate the padding required
73
+ hor_pad = new_width - width
74
+
75
+ # Add padding to the image
76
+ padded_image = add_margin(pil_image, 0, hor_pad, 0, 0, color=color)
77
+
78
+ # Convert the padded PIL Image back to numpy array
79
+ # padded_image_np = np.array(padded_image)
80
+
81
+ return padded_image
shared/utils/io.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for input-output loading/saving.
3
+ """
4
+
5
+ from typing import Any, List
6
+ import yaml
7
+ import pickle
8
+ import json
9
+ import pandas as pd
10
+
11
+
12
+ class PrettySafeLoader(yaml.SafeLoader):
13
+ """Custom loader for reading YAML files"""
14
+ def construct_python_tuple(self, node):
15
+ return tuple(self.construct_sequence(node))
16
+
17
+
18
+ PrettySafeLoader.add_constructor(
19
+ u'tag:yaml.org,2002:python/tuple',
20
+ PrettySafeLoader.construct_python_tuple
21
+ )
22
+
23
+
24
+ def load_yml(path: str, loader_type: str = 'default'):
25
+ """Read params from a yml file.
26
+
27
+ Args:
28
+ path (str): path to the .yml file
29
+ loader_type (str, optional): type of loader used to load yml files. Defaults to 'default'.
30
+
31
+ Returns:
32
+ Any: object (typically dict) loaded from .yml file
33
+ """
34
+ assert loader_type in ['default', 'safe']
35
+
36
+ loader = yaml.Loader if (loader_type == "default") else PrettySafeLoader
37
+
38
+ with open(path, 'r') as f:
39
+ data = yaml.load(f, Loader=loader)
40
+
41
+ return data
42
+
43
+
44
+ def save_yml(data: dict, path: str):
45
+ """Save params in the given yml file path.
46
+
47
+ Args:
48
+ data (dict): data object to save
49
+ path (str): path to .yml file to be saved
50
+ """
51
+ with open(path, 'w') as f:
52
+ yaml.dump(data, f, default_flow_style=False)
53
+
54
+
55
+ def load_pkl(path: str, encoding: str = "ascii"):
56
+ """Loads a .pkl file.
57
+
58
+ Args:
59
+ path (str): path to the .pkl file
60
+ encoding (str, optional): encoding to use for loading. Defaults to "ascii".
61
+
62
+ Returns:
63
+ Any: unpickled object
64
+ """
65
+ return pickle.load(open(path, "rb"), encoding=encoding)
66
+
67
+
68
+ def save_pkl(data: Any, path: str) -> None:
69
+ """Saves given object into .pkl file
70
+
71
+ Args:
72
+ data (Any): object to be saved
73
+ path (str): path to the location to be saved at
74
+ """
75
+ with open(path, 'wb') as f:
76
+ pickle.dump(data, f)
77
+
78
+
79
+ def load_json(path: str) -> dict:
80
+ """Helper to load json file"""
81
+ with open(path, 'rb') as f:
82
+ data = json.load(f)
83
+ return data
84
+
85
+
86
+ def save_json(data: dict, path: str):
87
+ """Helper to save `dict` as .json file."""
88
+ with open(path, 'w') as f:
89
+ json.dump(data, f)
90
+
91
+
92
+ def load_txt(path: str):
93
+ """Loads lines of a .txt file.
94
+
95
+ Args:
96
+ path (str): path to the .txt file
97
+
98
+ Returns:
99
+ List: lines of .txt file
100
+ """
101
+ with open(path) as f:
102
+ lines = f.read().splitlines()
103
+ return lines
104
+
105
+
106
+ def save_txt(data: dict, path: str):
107
+ """Writes data (lines) to a txt file.
108
+
109
+ Args:
110
+ data (dict): List of strings
111
+ path (str): path to .txt file
112
+ """
113
+ assert isinstance(data, list)
114
+
115
+ lines = "\n".join(data)
116
+ with open(path, "w") as f:
117
+ f.write(str(lines))
118
+
119
+
120
+ def read_spreadsheet(sheet_id, gid, url=None, drop_na=True, **kwargs):
121
+ if url is None:
122
+ BASE_URL = 'https://docs.google.com/spreadsheets/d/'
123
+ url = BASE_URL + sheet_id + f'/export?gid={gid}&format=csv'
124
+ df = pd.read_csv(url, **kwargs)
125
+
126
+ if drop_na:
127
+ # drop all rows which have atleast 1 NaN value
128
+ df = df.dropna(axis=0)
129
+
130
+ return df
131
+
132
+
133
+ def load_midi(file, rate=16000):
134
+ import pretty_midi
135
+ assert file.endswith('.mid')
136
+ pm = pretty_midi.PrettyMIDI(file)
137
+ y = pm.synthesize(fs=rate)
138
+ return y, rate
139
+
140
+
141
+ def load_ptz(path):
142
+ import gzip
143
+ import torch
144
+ with gzip.open(path, 'rb') as f:
145
+ data = torch.load(f)
146
+ return data
147
+
148
+
149
+ def save_video(frames, path, fps=30):
150
+ import imageio
151
+ imageio.mimwrite(path, frames, fps=fps)
shared/utils/keypoint_matching.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implements keypoint matching for a pair of images."""
2
+ import os
3
+ import numpy as np
4
+ import PIL
5
+ import cv2
6
+ import matplotlib.pyplot as plt
7
+
8
+
9
+ def show_single_image(img, figsize=(7, 5), title="Single image"):
10
+ """Displays a single image."""
11
+ fig = plt.figure(figsize=figsize)
12
+ plt.axis("off")
13
+ plt.imshow(img)
14
+ plt.title(title)
15
+ plt.show()
16
+
17
+
18
+ def show_two_images(img1, img2, title="Two images"):
19
+ """Displays a pair of images."""
20
+ fig, ax = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)
21
+
22
+ ax[0].axis("off")
23
+ ax[0].imshow(img1)
24
+
25
+ ax[1].axis("off")
26
+ ax[1].imshow(img2)
27
+
28
+ plt.suptitle(title)
29
+ plt.show()
30
+
31
+
32
+ def show_three_images(img1, img2, img3, ax1_title="", ax2_title="", ax3_title="", title="Three images"):
33
+ """Displays a triplet of images."""
34
+ fig, ax = plt.subplots(1, 3, figsize=(15, 5), constrained_layout=True)
35
+
36
+ ax[0].axis("off")
37
+ ax[0].imshow(img1)
38
+ ax[0].set_title(ax1_title)
39
+
40
+ ax[1].axis("off")
41
+ ax[1].imshow(img2)
42
+ ax[1].set_title(ax2_title)
43
+
44
+ ax[2].axis("off")
45
+ ax[2].imshow(img3)
46
+ ax[2].set_title(ax3_title)
47
+
48
+ plt.suptitle(title)
49
+ plt.show()
50
+
51
+
52
+ class KeypointMatcher:
53
+ """Class for Keypoint matching for a pair of images."""
54
+
55
+ def __init__(self, **sift_args) -> None:
56
+ self.SIFT = cv2.SIFT_create(**sift_args)
57
+ self.BFMatcher = cv2.BFMatcher()
58
+
59
+ @staticmethod
60
+ def _check_images(img1: np.ndarray, img2: np.ndarray):
61
+ assert isinstance(img1, np.ndarray)
62
+ assert len(img1.shape) == 2
63
+
64
+ assert isinstance(img2, np.ndarray)
65
+ assert len(img2.shape) == 2
66
+
67
+ # assert img1.shape == img2.shape
68
+
69
+ @staticmethod
70
+ def _show_matches(img1, kp1, img2, kp2, matches, K=10, figsize=(10, 5), drawMatches_args=dict(matchesThickness=3, singlePointColor=(0, 0, 0))):
71
+ """Displays matches found in the image"""
72
+ selected_matches = np.random.choice(matches, K)
73
+ img3 = cv2.drawMatches(img1, kp1, img2, kp2, selected_matches, outImg=None, **drawMatches_args)
74
+ show_single_image(img3, figsize=figsize, title=f"Randomly selected K = {K} matches between the pair of images.")
75
+ return img3
76
+
77
+ def match(self, img1: PIL.Image, img2: PIL.Image, show_matches: bool = True):
78
+ """Finds, describes and matches keypoints in given pair of images."""
79
+
80
+ img1 = np.array(img1)
81
+ img1 = cv2.cvtColor(img1, cv2.COLOR_RGB2GRAY)
82
+
83
+ img2 = np.array(img2)
84
+ img2 = cv2.cvtColor(img2, cv2.COLOR_RGB2GRAY)
85
+
86
+ # check input images
87
+ self._check_images(img1, img2)
88
+
89
+ # find kps and descriptors in each image
90
+ kp1, des1 = self.SIFT.detectAndCompute(img1, None)
91
+ kp2, des2 = self.SIFT.detectAndCompute(img2, None)
92
+
93
+ # compute matches via Brute-force matching
94
+ matches = self.BFMatcher.match(des1, des2)
95
+
96
+ # sort them in the order of their distance
97
+ matches = sorted(matches, key = lambda x:x.distance)
98
+
99
+ if show_matches:
100
+ self._show_matches(img1, kp1, img2, kp2, matches)
101
+
102
+ return matches, kp1, des1, kp2, des2
103
+
104
+
105
+ def warp(im, M, output_shape):
106
+ out = np.zeros((output_shape[0], output_shape[1]))
107
+ for i in range(output_shape[0]):
108
+ for j in range(output_shape[1]):
109
+ u, v = np.array([[i, j, 0, 0, 1, 0], [0, 0, i, j, 0, 1]]) @ M
110
+ u = int(round(u))
111
+ v = int(round(v))
112
+ if im.shape[0] > u >= 0 and im.shape[1] > v >= 0:
113
+ out[i, j] = im[u, v]
114
+
115
+ return out
116
+
117
+
118
+ def project_2d_to_6d(X: np.ndarray):
119
+ """Projects X (N x 2) to Z (2N x 6) space."""
120
+ N = len(X)
121
+ assert X.shape == (N, 2)
122
+
123
+ Z = np.zeros((2 * N, 6))
124
+ # in columns 0 to 2, fill even indexed rows of Z with X, and fill 5th column with 1
125
+ Z[::2, 0:2] = X
126
+ Z[::2, 4] = 1.0
127
+ # in columns 2 to 4, fill odd indexed rows of Z with X
128
+ Z[1::2, 2:4] = X
129
+ Z[1::2, 5] = 1.0
130
+
131
+ return Z
132
+
133
+
134
+ def project_6d_to_2d(Z: np.ndarray):
135
+ """Projects Z (2N x 6) to X (N x 2) space."""
136
+ N = len(Z) // 2
137
+ assert Z.shape == (2 * N, 6)
138
+
139
+ X_from_even_rows = Z[::2, 0:2]
140
+ X_from_odd_rows = Z[1::2, 2:4]
141
+ assert (X_from_even_rows == X_from_odd_rows).all()
142
+
143
+ return X_from_even_rows
144
+
145
+
146
+
147
+ def project_2d_to_1d(X: np.ndarray):
148
+ """Returns X (N x 2) from Z (2N, 1)"""
149
+ N = len(X)
150
+ X_stretched = np.zeros(2 * N)
151
+ X_stretched[::2] = X[:, 0]
152
+ X_stretched[1::2] = X[:, 1]
153
+ return X_stretched
154
+
155
+
156
+ def project_1d_to_2d(Z: np.ndarray):
157
+ """Returns X (N x 2) from Z (2N, 1)"""
158
+ N = len(Z) // 2
159
+ assert Z.shape == (2 * N,)
160
+
161
+ X = np.zeros((N, 2))
162
+ X[:, 0] = Z[::2]
163
+ X[:, 1] = Z[1::2]
164
+
165
+ return X
166
+
167
+
168
+ def rigid_body_transform(X: np.ndarray, params: np.ndarray):
169
+ """Performs rigid body transformation of points X (N x 2) using params (6 x 1 flattened)"""
170
+ N = len(X)
171
+ assert X.shape == (N, 2)
172
+
173
+ X = project_2d_to_6d(X)
174
+
175
+ X_transformed = np.matmul(X, params)
176
+ X_transformed = project_1d_to_2d(X_transformed)
177
+ assert X_transformed.shape == (N, 2)
178
+
179
+ return X_transformed
180
+
181
+
182
+ def rigid_body_transform_params(X1: np.ndarray, X2: np.ndarray):
183
+ """Returns rigid-body transform parameters RT (6 x 1) assuming transformation between X1 and X2"""
184
+ N = len(X1)
185
+ assert X1.shape == X2.shape
186
+ assert X1.shape == (N, 2)
187
+
188
+ # X2 = X1 * params => params = psuedoinverse(X1) * X2
189
+ X1_expanded = project_2d_to_6d(X1)
190
+ assert X1_expanded.shape == (2 * N, 6)
191
+
192
+ X2_stretched = project_2d_to_1d(X2)
193
+ assert X2_stretched.shape == (2 * N,)
194
+
195
+ params = np.dot(np.linalg.pinv(X1_expanded), X2_stretched)
196
+ return params
197
+
198
+
199
+ class ImageAlignment:
200
+ """Class to perform alignment of a pair of images given keypoints."""
201
+
202
+ def __init__(self) -> None:
203
+ pass
204
+
205
+ @staticmethod
206
+ def show_transformed_points(img1, img2, X1, kp1, kp2, matches, params, num_inliers, num_to_show=20):
207
+ import matplotlib.cm as cm
208
+
209
+ H1, W1 = img1.shape
210
+ H2, W2 = img2.shape
211
+ img = np.hstack([img1, img2])
212
+
213
+ random_matches = np.random.choice(matches, num_to_show)
214
+
215
+ fig, ax = plt.subplots(1, 1, figsize=(15, 6))
216
+ colors = cm.rainbow(np.linspace(0, 1, num_to_show))
217
+
218
+ for i, match in enumerate(random_matches):
219
+
220
+ # select a single match to visualize
221
+ x1, y1 = kp1[match.queryIdx].pt
222
+ x2, y2 = kp2[match.trainIdx].pt
223
+
224
+ # get (x1, y1) transformed to (x1_transformed, y1_transformed)
225
+ A = project_2d_to_6d(np.array([[x1, y1]]))
226
+ (x1_transformed, y1_transformed) = np.dot(A, params)
227
+
228
+ ax.imshow(img, cmap="gray")
229
+ ax.axis("off")
230
+ ax.scatter(x1_transformed + W1, y1_transformed, s=200, marker="x", color=colors[i])
231
+ ax.plot(
232
+ (x1, x1_transformed + W1), (y1, y1_transformed),
233
+ linestyle="--", color=colors[i], marker="o",
234
+ )
235
+
236
+ ax.set_title(
237
+ f"Points in image 1 mapped to transformed points estimated by {num_inliers} points.",
238
+ fontsize=18,
239
+ )
240
+
241
+ os.makedirs("./results/", exist_ok=True)
242
+ plt.savefig(f"./results/match_transformed_inliers_{num_inliers}.png", bbox_inches="tight")
243
+ plt.show()
244
+
245
+ def ransac(
246
+ self, img1, kp1, img2, kp2, matches, num_matches=6, max_iter=500,
247
+ radius_in_px=10, show_transformed=True, inlier_th_for_show=1000
248
+ ):
249
+ """Performs RANSAC to find best matches."""
250
+
251
+ best_inlier_count = 0
252
+ best_params = None
253
+
254
+ # get coordinates of all points in image 1
255
+ X1 = np.array([kp1[matches[i].queryIdx].pt for i in range(len(matches))])
256
+
257
+ # get coordinates of all points in image 2
258
+ X2 = np.array([kp2[matches[i].trainIdx].pt for i in range(len(matches))])
259
+
260
+ for i in range(max_iter):
261
+ # choose matches randomly
262
+ selected_matches = np.random.choice(matches, num_matches)
263
+
264
+ # get matched keypoints in img1
265
+ X1_selected = np.array([kp1[selected_matches[i].queryIdx].pt for i in range(len(selected_matches))])
266
+
267
+ # get matched keypoints in img2
268
+ X2_selected = np.array([kp2[selected_matches[i].trainIdx].pt for i in range(len(selected_matches))])
269
+
270
+ # get transformation parameters
271
+ params = rigid_body_transform_params(X1_selected, X2_selected)
272
+
273
+ # transform X1 to get X2_transformed
274
+ X2_transformed = rigid_body_transform(X1, params)
275
+
276
+ # find inliers
277
+ diff = np.linalg.norm(X2_transformed - X2, axis=1)
278
+ indices = diff < radius_in_px
279
+ num_inliers = sum(indices)
280
+ if num_inliers > best_inlier_count:
281
+ print(f"Found {num_inliers} inliers!")
282
+ best_params = params
283
+ best_inlier_count = num_inliers
284
+
285
+ if show_transformed and num_inliers > inlier_th_for_show:
286
+ self.show_transformed_points(img1, img2, X1, kp1, kp2, matches, best_params, num_inliers)
287
+
288
+ return best_params
289
+
290
+ def align(
291
+ self, img1, kp1, img2, kp2, matches, num_matches=6,
292
+ max_iter=500, show_warped_image=True,
293
+ save_warped=False, path="results/sample.png",
294
+ method="custom"
295
+ ):
296
+ best_params = self.ransac(img1, kp1, img2, kp2, matches, max_iter=max_iter, num_matches=num_matches)
297
+
298
+ # apply the affine transformation using cv2.warpAffine()
299
+ rows, cols = img1.shape[:2]
300
+
301
+ if method == 'custom':
302
+ img1_warped = warp(img1, best_params, (rows, cols))
303
+ else:
304
+ M = np.zeros((2, 3))
305
+ M[0, :2] = best_params[:2]
306
+ M[1, :2] = best_params[2:4]
307
+ M[0, 2] = best_params[4]
308
+ M[1, 2] = best_params[5]
309
+ img1_warped = cv2.warpAffine(img1, M, (cols, rows))
310
+
311
+ if show_warped_image:
312
+ show_three_images(
313
+ img1, img2, img1_warped, title="",
314
+ ax1_title="Image 1", ax2_title="Image 2", ax3_title="Transformation: Image 1 to Image 2",
315
+ )
316
+
317
+ if save_warped:
318
+ plt.imsave(path, img1_warped)
319
+
320
+ return best_params
321
+
322
+
323
+ if __name__ == "__main__":
324
+ # read & show images
325
+ boat1 = cv2.imread('boat1.pgm', cv2.IMREAD_GRAYSCALE)
326
+ boat2 = cv2.imread('boat2.pgm', cv2.IMREAD_GRAYSCALE)
327
+ show_two_images(boat1, boat2, title="Given pair of images.")
328
+
329
+ kp_matcher = KeypointMatcher(contrastThreshold=0.1, edgeThreshold=5)
330
+ matches, kp1, des1, kp2, des2 = kp_matcher.match(boat1, boat2, show_matches=True)
shared/utils/log.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Loggers."""
2
+ import os
3
+ from os.path import dirname, realpath, abspath
4
+ from tqdm.auto import tqdm
5
+ import numpy as np
6
+
7
+
8
+ curr_filepath = abspath(__file__)
9
+ repo_path = dirname(dirname(dirname(curr_filepath)))
10
+ # repo_path = dirname(dirname(dirname(realpath(__file__))))
11
+
12
+ def tqdm_iterator(items, desc=None, bar_format=None, **kwargs):
13
+ tqdm._instances.clear()
14
+ iterator = tqdm(
15
+ items,
16
+ desc=desc,
17
+ # bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}',
18
+ **kwargs,
19
+ )
20
+ tqdm._instances.clear()
21
+
22
+ return iterator
23
+
24
+
25
+ def print_retrieval_metrics_for_csv(metrics, scale=100):
26
+ print_string = [
27
+ np.round(scale * metrics["R1"], 3),
28
+ np.round(scale * metrics["R5"], 3),
29
+ np.round(scale * metrics["R10"], 3),
30
+ ]
31
+ if "MR" in metrics:
32
+ print_string += [metrics["MR"]]
33
+ print()
34
+ print("Final metrics: ", ",".join([str(x) for x in print_string]))
35
+ print()
36
+
37
+
38
+ def print_update(update, fillchar=":", color="yellow", pos="center"):
39
+ from termcolor import colored
40
+ # add ::: to the beginning and end of the update s.t. the total length of the
41
+ # update spans the whole terminal
42
+ try:
43
+ terminal_width = os.get_terminal_size().columns - 2
44
+ except:
45
+ terminal_width = 98
46
+ if pos == "center":
47
+ update = update.center(len(update) + 2, " ")
48
+ update = update.center(terminal_width, fillchar)
49
+ elif pos == "left":
50
+ update = update.ljust(terminal_width, fillchar)
51
+ update = update.ljust(len(update) + 2, " ")
52
+ elif pos == "right":
53
+ update = update.rjust(terminal_width, fillchar)
54
+ update = update.rjust(len(update) + 2, " ")
55
+ else:
56
+ raise ValueError("pos must be one of 'center', 'left', 'right'")
57
+ print(colored(update, color))
58
+
59
+
60
+ def json_print(data, indent=4):
61
+ import json
62
+ print(json.dumps(data, indent=indent))
63
+
64
+
65
+ def get_terminal_width():
66
+ import shutil
67
+ return shutil.get_terminal_size().columns
68
+
69
+
70
+ if __name__ == "__main__":
71
+ print("Repo path:", repo_path)
72
+
shared/utils/metrics.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Helpers for metric functions"""
2
+ import numpy as np
3
+ import pandas as pd
4
+
5
+
6
+ def calculate_iou(box1, box2):
7
+ """
8
+ Calculate Intersection over Union (IoU) between two bounding boxes.
9
+
10
+ Args:
11
+ box1 (tuple): Coordinates of the first bounding box in the format (x1, y1, x2, y2).
12
+ box2 (tuple): Coordinates of the second bounding box in the format (x1, y1, x2, y2).
13
+
14
+ Returns:
15
+ float: Intersection over Union (IoU) score.
16
+ """
17
+ # Extract coordinates
18
+ x1, y1, x2, y2 = box1
19
+ x1_, y1_, x2_, y2_ = box2
20
+
21
+ # Calculate the intersection area
22
+ intersection_area = max(0, min(x2, x2_) - max(x1, x1_)) * max(0, min(y2, y2_) - max(y1, y1_))
23
+
24
+ # Calculate the areas of each bounding box
25
+ box1_area = (x2 - x1) * (y2 - y1)
26
+ box2_area = (x2_ - x1_) * (y2_ - y1_)
27
+
28
+ # Calculate IoU
29
+ iou = intersection_area / float(box1_area + box2_area - intersection_area)
30
+
31
+ return iou
32
+
33
+
34
+ def compute_intersection_1d(x, y):
35
+ # sort the boxes
36
+ x1, x2 = sorted(x)
37
+ y1, y2 = sorted(y)
38
+
39
+ # compute the intersection
40
+ intersection = max(0, min(x2, y2) - max(x1, y1))
41
+
42
+ return intersection
43
+
44
+ def compute_union_1d(x, y):
45
+ # sort the boxes
46
+ x1, x2 = sorted(x)
47
+ y1, y2 = sorted(y)
48
+
49
+ # compute the union
50
+ union = max(x2, y2) - min(x1, y1)
51
+
52
+ return union
53
+
54
+
55
+ def compute_iou_1d(pred_box, true_box):
56
+ """
57
+ Compute IoU for 1D boxes.
58
+
59
+ Args:
60
+ pred_box (float): Predicted box, [x1, x2]
61
+ true_box (float): Ground truth box, [x1, x2]
62
+
63
+ Returns:
64
+ float: IoU
65
+ """
66
+ intersection = compute_intersection_1d(pred_box, true_box)
67
+ union = compute_union_1d(pred_box, true_box)
68
+ iou = intersection / union
69
+ return iou
70
+
71
+
72
+ def compute_iou_1d_single_candidate_multiple_targets(pred_box, true_boxes):
73
+ """
74
+ Compute IoU for 1D boxes.
75
+
76
+ Args:
77
+ pred_box (float): Predicted box, [x1, x2]
78
+ true_boxes (np.ndarray): Ground truth boxes, shape: (N, 2)
79
+
80
+ Returns:
81
+ float: IoU
82
+ """
83
+ ious = []
84
+ for i, true_box in enumerate(true_boxes):
85
+ ious.append(compute_iou_1d(pred_box, true_box))
86
+ return np.array(ious)
87
+
88
+
89
+ def compute_iou_1d_multiple_candidates_multiple_targets(pred_boxes, true_boxes):
90
+ """
91
+ Compute IoU for 1D boxes.
92
+
93
+ Args:
94
+ pred_boxes (np.ndarray): Predicted boxes, shape: (N, 2)
95
+ true_boxes (np.ndarray): Ground truth boxes, shape: (N, 2)
96
+
97
+ Returns:
98
+ float: IoU
99
+ """
100
+ iou_matrix = np.zeros((len(pred_boxes), len(true_boxes)))
101
+ for i, pred_box in enumerate(pred_boxes):
102
+ for j, true_box in enumerate(true_boxes):
103
+ iou_matrix[i, j] = compute_iou_1d(pred_box, true_box)
104
+ return iou_matrix
105
+
106
+
107
+ def compute_mean_iou_1d(pred_boxes, gt_boxes, threshold=0.5):
108
+ """
109
+ Computes mean IOU for 1D bounding boxes.
110
+
111
+ Args:
112
+ pred_boxes (np.ndarray): Predicted boxes, shape: (N, 2)
113
+ gt_boxes (np.ndarray): Ground truth boxes, shape: (N, 2)
114
+ threshold (float): Threshold to consider a prediction correct
115
+
116
+ Returns:
117
+ float: Mean IOU
118
+ """
119
+ # Compute IoU for each pair of boxes
120
+ iou_matrix = np.zeros((len(pred_boxes), len(gt_boxes)))
121
+ for i, pred_box in enumerate(pred_boxes):
122
+ for j, gt_box in enumerate(gt_boxes):
123
+ iou_matrix[i, j] = compute_iou_1d(pred_box, gt_box)
124
+
125
+ # Compute the max IoU for each predicted box
126
+ max_iou_indices = np.argmax(iou_matrix, axis=1)
127
+ max_iou = iou_matrix[np.arange(len(pred_boxes)), max_iou_indices]
128
+
129
+ # For each predicted box, compute TP and FP ground truth boxes
130
+ tp = np.zeros(len(pred_boxes))
131
+ fp = np.zeros(len(pred_boxes))
132
+ iou = np.zeros(len(pred_boxes))
133
+
134
+ tp = np.where(iou_matrix >= threshold, 1, 0)
135
+ tp = max_iou >= threshold
136
+ fp = max_iou < threshold
137
+ iou = max_iou
138
+ mean_iou = np.mean(iou)
139
+ import ipdb; ipdb.set_trace()
140
+
141
+
142
+
143
+
144
+
145
+
146
+
147
+ def calculate_mAP_1d(pred_boxes, pred_scores, true_boxes, iou_thresh=0.5):
148
+ """Calculate mean average precision for 1D boxes.
149
+
150
+ Args:
151
+ pred_boxes (numpy array): Predicted boxes, shape (num_boxes,)
152
+ pred_scores (numpy array): Predicted scores, shape (num_boxes,)
153
+ true_boxes (numpy array): Ground truth boxes, shape (num_boxes,)
154
+ iou_thresh (float): IoU threshold to consider a prediction correct
155
+
156
+ Returns:
157
+ float: Mean average precision (mAP)
158
+ """
159
+ # Sort predicted boxes by score (in descending order)
160
+ sort_inds = np.argsort(pred_scores)[::-1]
161
+ pred_boxes = pred_boxes[sort_inds]
162
+ pred_scores = pred_scores[sort_inds]
163
+
164
+ # Compute true positives and false positives at each threshold
165
+ tp = np.zeros(len(pred_boxes))
166
+ fp = np.zeros(len(pred_boxes))
167
+ for i, box in enumerate(pred_boxes):
168
+ ious = np.abs(box - true_boxes) / np.maximum(1e-9, np.abs(box) + np.abs(true_boxes))
169
+ if len(ious) > 0:
170
+ max_iou_idx = np.argmax(ious)
171
+ if ious[max_iou_idx] >= iou_thresh:
172
+ if tp[max_iou_idx] == 0:
173
+ tp[i] = 1
174
+ fp[i] = 0
175
+ else:
176
+ fp[i] = 1
177
+ else:
178
+ fp[i] = 1
179
+
180
+ # Compute precision and recall at each threshold
181
+ tp_cumsum = np.cumsum(tp)
182
+ fp_cumsum = np.cumsum(fp)
183
+ recall = tp_cumsum / len(true_boxes)
184
+ precision = tp_cumsum / (tp_cumsum + fp_cumsum)
185
+
186
+ # Compute AP as area under precision-recall curve
187
+ ap = 0
188
+ for t in np.arange(0, 1.1, 0.1):
189
+ if np.sum(recall >= t) == 0:
190
+ p = 0
191
+ else:
192
+ p = np.max(precision[recall >= t])
193
+ ap += p / 11
194
+
195
+ return ap
196
+
197
+
198
+ def segment_iou(target_segment, candidate_segments):
199
+ """Compute the temporal intersection over union between a
200
+ target segment and all the test segments.
201
+ Parameters
202
+ ----------
203
+ target_segment : 1d array
204
+ Temporal target segment containing [starting, ending] times.
205
+ candidate_segments : 2d array
206
+ Temporal candidate segments containing N x [starting, ending] times.
207
+ Outputs
208
+ -------
209
+ tiou : 1d array
210
+ Temporal intersection over union score of the N's candidate segments.
211
+ """
212
+ tt1 = np.maximum(target_segment[0], candidate_segments[:, 0])
213
+ tt2 = np.minimum(target_segment[1], candidate_segments[:, 1])
214
+ # Intersection including Non-negative overlap score.
215
+ segments_intersection = (tt2 - tt1).clip(0)
216
+ # Segment union.
217
+ segments_union = (candidate_segments[:, 1] - candidate_segments[:, 0]) \
218
+ + (target_segment[1] - target_segment[0]) - segments_intersection
219
+ # Compute overlap as the ratio of the intersection
220
+ # over union of two segments.
221
+ tIoU = segments_intersection.astype(float) / segments_union
222
+ return tIoU
223
+
224
+
225
+ def interpolated_prec_rec(prec, rec):
226
+ """Interpolated AP - VOCdevkit from VOC 2011.
227
+ """
228
+ mprec = np.hstack([[0], prec, [0]])
229
+ mrec = np.hstack([[0], rec, [1]])
230
+ for i in range(len(mprec) - 1)[::-1]:
231
+ mprec[i] = max(mprec[i], mprec[i + 1])
232
+ idx = np.where(mrec[1::] != mrec[0:-1])[0] + 1
233
+ ap = np.sum((mrec[idx] - mrec[idx - 1]) * mprec[idx])
234
+ return ap
235
+
236
+
237
+ from tqdm import tqdm
238
+ def compute_average_precision_detection(
239
+ ground_truth,
240
+ prediction,
241
+ tiou_thresholds=np.linspace(0.5, 0.95, 10),
242
+ ):
243
+ """Compute average precision (detection task) between ground truth and
244
+ predictions data frames. If multiple predictions occurs for the same
245
+ predicted segment, only the one with highest score is matches as
246
+ true positive. This code is greatly inspired by Pascal VOC devkit.
247
+
248
+ Ref: https://github.com/zhang-can/CoLA/blob/\
249
+ d21f1b5a4c6c13f9715cfd4ac1ebcd065d179157/eval/eval_detection.py#L200
250
+
251
+ Parameters
252
+ ----------
253
+ ground_truth : df
254
+ Data frame containing the ground truth instances.
255
+ Required fields: ['video-id', 't-start', 't-end']
256
+ prediction : df
257
+ Data frame containing the prediction instances.
258
+ Required fields: ['video-id, 't-start', 't-end', 'score']
259
+ tiou_thresholds : 1darray, optional
260
+ Temporal intersection over union threshold.
261
+ Outputs
262
+ -------
263
+ ap : float
264
+ Average precision score.
265
+ """
266
+ ap = np.zeros(len(tiou_thresholds))
267
+ if prediction.empty:
268
+ return ap
269
+
270
+ npos = float(len(ground_truth))
271
+ lock_gt = np.ones((len(tiou_thresholds),len(ground_truth))) * -1
272
+ # Sort predictions by decreasing score order.
273
+ sort_idx = prediction['score'].values.argsort()[::-1]
274
+ prediction = prediction.loc[sort_idx].reset_index(drop=True)
275
+
276
+ # Initialize true positive and false positive vectors.
277
+ tp = np.zeros((len(tiou_thresholds), len(prediction)))
278
+ fp = np.zeros((len(tiou_thresholds), len(prediction)))
279
+
280
+ # Adaptation to query faster
281
+ ground_truth_gbvn = ground_truth.groupby('video-id')
282
+
283
+ # Assigning true positive to truly grount truth instances.
284
+ for idx, this_pred in prediction.iterrows():
285
+
286
+ try:
287
+ # Check if there is at least one ground truth in the video associated.
288
+ ground_truth_videoid = ground_truth_gbvn.get_group(this_pred['video-id'])
289
+ except Exception as e:
290
+ fp[:, idx] = 1
291
+ continue
292
+
293
+ this_gt = ground_truth_videoid.reset_index()
294
+ tiou_arr = segment_iou(this_pred[['t-start', 't-end']].values,
295
+ this_gt[['t-start', 't-end']].values)
296
+ # We would like to retrieve the predictions with highest tiou score.
297
+ tiou_sorted_idx = tiou_arr.argsort()[::-1]
298
+ for tidx, tiou_thr in enumerate(tiou_thresholds):
299
+ for jdx in tiou_sorted_idx:
300
+ if tiou_arr[jdx] < tiou_thr:
301
+ fp[tidx, idx] = 1
302
+ break
303
+ if lock_gt[tidx, this_gt.loc[jdx]['index']] >= 0:
304
+ continue
305
+ # Assign as true positive after the filters above.
306
+ tp[tidx, idx] = 1
307
+ lock_gt[tidx, this_gt.loc[jdx]['index']] = idx
308
+ break
309
+
310
+ if fp[tidx, idx] == 0 and tp[tidx, idx] == 0:
311
+ fp[tidx, idx] = 1
312
+
313
+ tp_cumsum = np.cumsum(tp, axis=1).astype(float)
314
+ fp_cumsum = np.cumsum(fp, axis=1).astype(float)
315
+ recall_cumsum = tp_cumsum / npos
316
+
317
+ precision_cumsum = tp_cumsum / (tp_cumsum + fp_cumsum)
318
+
319
+ for tidx in range(len(tiou_thresholds)):
320
+ ap[tidx] = interpolated_prec_rec(precision_cumsum[tidx,:], recall_cumsum[tidx,:])
321
+
322
+
323
+ return ap
324
+
325
+
326
+ def ap_wrapper(
327
+ true_clips,
328
+ pred_clips,
329
+ pred_scores,
330
+ tiou_thresholds=np.linspace(0.5, 0.95, 10),
331
+ ):
332
+ assert isinstance(true_clips, np.ndarray)
333
+ assert len(true_clips.shape) == 2 and true_clips.shape[1] == 2
334
+ assert isinstance(pred_clips, np.ndarray)
335
+ assert len(pred_clips.shape) == 2 and pred_clips.shape[1] == 2
336
+ assert isinstance(pred_scores, np.ndarray)
337
+ assert len(pred_scores.shape) == 1 and len(pred_scores) == pred_clips.shape[0]
338
+
339
+ true_df = pd.DataFrame(
340
+ {
341
+ "video-id": ["video1"] * len(true_clips),
342
+ "t-start": true_clips[:, 0],
343
+ "t-end": true_clips[:, 1],
344
+ }
345
+ )
346
+ pred_df = pd.DataFrame(
347
+ {
348
+ "video-id": ["video1"] * len(pred_clips),
349
+ "t-start": pred_clips[:, 0],
350
+ "t-end": pred_clips[:, 1],
351
+ "score": pred_scores,
352
+ }
353
+ )
354
+ return compute_average_precision_detection(
355
+ true_df,
356
+ pred_df,
357
+ tiou_thresholds=tiou_thresholds,
358
+ )
359
+
360
+
361
+ def nms_1d(df: pd.DataFrame, score_col="score", iou_thresh=0.5):
362
+ """Applies NMS on 1D (start, end) box predictions."""
363
+ columns = set(df.columns)
364
+ # assert columns == set(["video_id", "start", "end", "score"])
365
+ assert set(["start", "end", "video_id", score_col]).issubset(columns)
366
+ video_ids = df["video_id"].unique()
367
+
368
+ # Group by video_id
369
+ groups = df.groupby("video_id")
370
+
371
+ # Loop over videos
372
+ keep_indices = []
373
+ net_success_fraction = []
374
+ tqdm._instances.clear()
375
+ iterator = tqdm(
376
+ video_ids,
377
+ desc="Applying NMS to each video",
378
+ bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}',
379
+ )
380
+ for video_id in iterator:
381
+
382
+ # Get rows for this video
383
+ rows = groups.get_group(video_id)
384
+
385
+ # Sort by score
386
+ rows = rows.sort_values(score_col, ascending=False)
387
+
388
+ # Loop over rows until empty
389
+ n_clips = len(rows)
390
+ n_clips_selected_in_video = 0
391
+ while len(rows):
392
+
393
+ # Add top row to keep_indices
394
+ top_row = rows.iloc[0]
395
+ keep_indices.append(rows.index[0])
396
+ n_clips_selected_in_video += 1
397
+ top_row = top_row.to_dict()
398
+
399
+ top_segment = np.array([top_row["start"], top_row["end"]])
400
+ rows = rows.iloc[1:]
401
+ other_segments = rows[["start", "end"]].values
402
+ iou_values = segment_iou(top_segment, other_segments)
403
+
404
+ # Remove rows IoU > iou_thresh
405
+ rows = rows[iou_values < iou_thresh]
406
+
407
+ net_success_fraction.append(n_clips_selected_in_video / n_clips)
408
+ net_success_fraction = np.array(net_success_fraction).mean()
409
+ print("> Net success fraction: {:.2f}".format(net_success_fraction))
410
+
411
+ return keep_indices
412
+
413
+
414
+ if __name__ == "__main__":
415
+ true_clips = np.array(
416
+ [
417
+ [0.1, 0.7],
418
+ [3.4, 7.8],
419
+ [3.9, 5.4],
420
+ ]
421
+ )
422
+ pred_clips = np.array(
423
+ [
424
+ [0.2, 0.8],
425
+ [3.5, 7.9],
426
+ [3.9, 5.4],
427
+ [5.6, 6.7],
428
+ [6.0, 6.5],
429
+ ],
430
+ )
431
+ pred_scores = np.array([0.9, 0.8, 0.7, 0.6, 0.5])
432
+
433
+ # 1. Check IoU for a single pair of boxes
434
+ iou = compute_iou_1d(pred_clips[0], true_clips[0])
435
+ # Manually check that the result is correct
436
+ # Clips are [0.1, 0.7] and [0.2, 0.8]
437
+ # Intersection: [0.2, 0.7] - length = 0.5
438
+ # Union: [0.1, 0.8] - length = 0.7
439
+ # Ratio: 0.5 / 0.7 = 0.714
440
+ assert np.isclose(iou, 0.714, 3), "Incorrect IoU"
441
+
442
+ # 2. Check IoU for a single predicted box and multiple ground truth boxes
443
+ ious = compute_iou_1d_single_candidate_multiple_targets(pred_clips[0], true_clips)
444
+ assert np.allclose(ious, [0.714, 0.0, 0.0], 3), "Incorrect IoU"
445
+
446
+ # 3. Check mean IoU for multiple predicted boxes and multiple ground truth boxes
447
+ ious = compute_iou_1d_multiple_candidates_multiple_targets(pred_clips, true_clips)
448
+ assert ious.shape == (5, 3), "Incorrect shape"
449
+
450
+ ap = ap_wrapper(
451
+ true_clips,
452
+ pred_clips,
453
+ pred_scores,
454
+ tiou_thresholds=np.linspace(0.5, 0.95, 3),
455
+ )
456
+ # Take the mean of the APs across IoU thresholds
457
+ final_ap = np.mean(ap)
458
+ import ipdb; ipdb.set_trace()
shared/utils/misc.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Misc utils."""
2
+ import os
3
+ from shared.utils.log import tqdm_iterator
4
+ import numpy as np
5
+
6
+
7
+ class AttrDict(dict):
8
+ def __init__(self, *args, **kwargs):
9
+ super(AttrDict, self).__init__(*args, **kwargs)
10
+ self.__dict__ = self
11
+
12
+
13
+ def ignore_warnings(type="ignore"):
14
+ import warnings
15
+ warnings.filterwarnings(type)
16
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
17
+
18
+
19
+ def download_youtube_video(youtube_id, ext='mp4', resolution="360p", **kwargs):
20
+ import pytube
21
+ video_url = f"https://www.youtube.com/watch?v={youtube_id}"
22
+ yt = pytube.YouTube(video_url)
23
+ try:
24
+ streams = yt.streams.filter(
25
+ file_extension=ext, res=resolution, progressive=True, **kwargs,
26
+ )
27
+ # streams[0].download(output_path=save_dir, filename=f"{video_id}.{ext}")
28
+ streams[0].download(output_path='/tmp', filename='sample.mp4')
29
+ except:
30
+ print("Failed to download video: ", video_url)
31
+ return None
32
+ return "/tmp/sample.mp4"
33
+
34
+
35
+ def check_audio(video_path):
36
+ from moviepy.video.io.VideoFileClip import VideoFileClip
37
+ try:
38
+ return VideoFileClip(video_path).audio is not None
39
+ except Exception:
40
+ return False
41
+
42
+
43
+ def check_audio_multiple(video_paths, n_jobs=8):
44
+ """Parallelly check if videos have audio"""
45
+ iterator = tqdm_iterator(video_paths, desc="Checking audio")
46
+ from joblib import Parallel, delayed
47
+ return Parallel(n_jobs=n_jobs)(
48
+ delayed(check_audio)(video_path) for video_path in iterator
49
+ )
50
+
51
+
52
+ def num_trainable_params(model, round=3, verbose=True, return_count=False):
53
+ n_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
54
+ model_name = model.__class__.__name__
55
+ if round is not None:
56
+ value = np.round(n_params / 1e6, round)
57
+ unit = "M"
58
+ else:
59
+ value = n_params
60
+ unit = ""
61
+ if verbose:
62
+ print(f"::: Number of trainable parameters in {model_name}: {value} {unit}")
63
+ if return_count:
64
+ return n_params
65
+
66
+
67
+ def num_params(model, round=3):
68
+ n_params = sum([p.numel() for p in model.parameters()])
69
+ model_name = model.__class__.__name__
70
+ if round is not None:
71
+ value = np.round(n_params / 1e6, round)
72
+ unit = "M"
73
+ else:
74
+ value = n_params
75
+ unit = ""
76
+ print(f"::: Number of total parameters in {model_name}: {value}{unit}")
77
+
78
+
79
+ def fix_seed(seed=42):
80
+ """Fix all numpy/pytorch/random seeds."""
81
+ import random
82
+ import torch
83
+ import numpy as np
84
+ random.seed(seed)
85
+ np.random.seed(seed)
86
+ torch.manual_seed(seed)
87
+ torch.cuda.manual_seed_all(seed)
88
+ torch.backends.cudnn.deterministic = True
89
+
90
+
91
+ def check_tensor(x):
92
+ print(x.shape, x.min(), x.max())
93
+
94
+
95
+ def find_nearest_indices(a, b):
96
+ """
97
+ Finds the indices of the elements in `a` that are closest to each element in `b`.
98
+
99
+ Args:
100
+ a (np.ndarray): The array to search for the closest values.
101
+ b (np.ndarray): The array of values to search for.
102
+
103
+ Returns:
104
+ np.ndarray: The indices of the closest values in `a` for each element in `b`.
105
+ """
106
+ # Reshape `a` and `b` to make use of broadcasting
107
+ a = np.array(a)
108
+ b = np.array(b)
109
+
110
+ # Calculate the absolute difference between each element in `b` and all elements in `a`
111
+ diff = np.abs(a - b[:, np.newaxis])
112
+
113
+ # Find the index of the minimum value along the second axis (which corresponds to `a`)
114
+ indices = np.argmin(diff, axis=1)
115
+
116
+ return indices
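A small usage sketch (not part of this commit) for find_nearest_indices, assuming the module is importable as shared.utils.misc:

import numpy as np
from shared.utils.misc import find_nearest_indices

a = np.array([0.0, 0.5, 1.0, 1.5])
b = np.array([0.4, 1.2])
# Broadcasting gives |a - b[:, None]| of shape (len(b), len(a)); argmin over
# axis=1 picks, for each entry of b, the index of the closest entry of a.
print(find_nearest_indices(a, b))  # -> [1 2]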
shared/utils/pandas_utils.py ADDED
@@ -0,0 +1,117 @@
1
+ """Utility functions for pandas operations"""
2
+
3
+ from typing import List
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+
8
+ def apply_filters(df: pd.DataFrame, filters: dict, reset_index=False):
9
+ """
10
+ Filters df based on given filters (key-values pairs).
11
+ """
12
+ import omegaconf
13
+ X = df.copy()
14
+
15
+ all_indices = []
16
+ for col, values in filters.items():
17
+ if isinstance(values, (list, tuple, np.ndarray, omegaconf.listconfig.ListConfig)):
18
+ indices = X[col].isin(list(values))
19
+ else:
20
+ indices = X[col] == values
21
+ all_indices.append(indices)
22
+ # print(col, values, len(indices), sum(indices))
23
+ # X = X[indices]
24
+ if len(all_indices):
25
+ all_indices = np.array(all_indices)
26
+ indices = np.all(all_indices, axis=0)
27
+ X = X[indices]
28
+
29
+ if reset_index:
30
+ X = X.reset_index(drop=True)
31
+
32
+ return X
33
+
34
+
35
+ def apply_antifilters(df: pd.DataFrame, filters: dict, reset_index=False):
36
+ """
37
+ Filters df removing rows for given filters (key-values pairs).
38
+ """
39
+ X = df.copy()
40
+
41
+ for col, values in filters.items():
42
+ if isinstance(values, (list, tuple, np.ndarray)):
43
+ indices = X[col].isin(list(values))
44
+ else:
45
+ indices = X[col] == values
46
+ X = X[~indices]
47
+
48
+ if reset_index:
49
+ X = X.reset_index(drop=True)
50
+
51
+ return X
52
+
53
+
54
+ def custom_eval(x):
55
+ """Splits string '["a", "b", "c"]' into ["a", "b", "c"]."""
56
+ if isinstance(x, str):
57
+ x = x.replace('[', '')
58
+ x = x.replace(']', '')
59
+
60
+ x = x.split(',')
61
+ x = [y.rstrip().lstrip() for y in x]
62
+ return x
63
+ else:
64
+ return ['NA']
65
+
66
+
67
+ def split_column_into_columns(df, column):
68
+ """
69
+ For given df, splits `column` containing values like '["a", "b"]'
70
+ into one-hot subcolumns like `a`, `b` with `Yes`/`No` values.
71
+ """
72
+ df[column] = df[column].apply(custom_eval)
73
+
74
+ unique_values = []
75
+ for i in range(len(df)):
76
+ index = df.index[i]
77
+
78
+ list_of_values = df.loc[index, column]
79
+
80
+ for x in list_of_values:
81
+ if (x != 'NA') and (x != ''):
82
+ df.at[index, x] = 'Yes'
83
+ if x not in unique_values:
84
+ unique_values.append(x)
85
+
86
+ df[unique_values] = df[unique_values].fillna('No')
87
+ df[f'any_{column}'] = df[unique_values].apply(
88
+ lambda x: 'Yes' if 'Yes' in list(x) else 'No', axis=1
89
+ )
90
+ return df
91
+
92
+
93
+ def custom_read_csv(path: str, columns_to_onehot: List) -> pd.DataFrame:
94
+ """Custom CSV reader
95
+
96
+ Args:
97
+ path (str): path to .csv file
98
+ columns_to_onehot (List): list of columns to one-hotify
99
+
100
+ Returns:
101
+ pd.DataFrame: loaded df
102
+ """
103
+ df = pd.read_csv(path)
104
+ for column in columns_to_onehot:
105
+ df = split_column_into_columns(df, column)
106
+ return df
107
+
108
+
109
+ def split_df(df, test_size=0.2):
110
+ from sklearn.model_selection import train_test_split
111
+ # split the dataframe into train and test sets
112
+ train_df, test_df = train_test_split(df, test_size=test_size, random_state=42)
113
+
114
+ # split the train set into train and validation sets
115
+ train_df, val_df = train_test_split(train_df, test_size=test_size, random_state=42)
116
+
117
+ return train_df, val_df, test_df
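A usage sketch (not part of this commit) for apply_filters, assuming the module is importable as shared.utils.pandas_utils and that omegaconf (imported inside the function) is installed; the tiny DataFrame below is made up for illustration:

import pandas as pd
from shared.utils.pandas_utils import apply_filters

df = pd.DataFrame({
    "shape": ["cylindrical", "semiconical", "bottleneck"],
    "visibility": ["opaque", "transparent", "opaque"],
})
# Keep rows where shape is in the given list AND visibility equals "opaque".
out = apply_filters(df, {"shape": ["cylindrical", "bottleneck"], "visibility": "opaque"})
print(out)  # rows 0 and 2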
shared/utils/paths.py ADDED
@@ -0,0 +1,35 @@
1
+ """Path utils."""
2
+ from os.path import dirname, abspath
3
+
4
+
5
+ curr_filepath = abspath(__file__)
6
+ repo_path = dirname(dirname(dirname(curr_filepath)))
7
+
8
+
9
+ def get_data_root_from_hostname():
10
+ import socket
11
+
12
+ data_root_lib = {
13
+ "diva": "/ssd/pbagad/datasets/",
14
+ "node": "/var/scratch/pbagad/datasets/",
15
+ "fs4": "/var/scratch/pbagad/datasets/",
16
+ "vggdev21": "/scratch/shared/beegfs/piyush/datasets/",
17
+ "node407": "/var/scratch/pbagad/datasets/",
18
+ "gnodee5": "/scratch/shared/beegfs/piyush/datasets/",
19
+ "gnodeg2": "/scratch/shared/beegfs/piyush/datasets/",
20
+ "gnodec2": "/scratch/shared/beegfs/piyush/datasets/",
21
+ "Piyushs-MacBook-Pro": "/Users/piyush/projects/",
22
+ "gnodec1": "/scratch/shared/beegfs/piyush/datasets/",
23
+ "gnodec5": "/scratch/shared/beegfs/piyush/datasets/",
24
+ "gnodec4": "/scratch/shared/beegfs/piyush/datasets/",
25
+ "gnoded2": "/scratch/shared/beegfs/piyush/datasets/",
26
+ }
27
+ hostname = socket.gethostname()
28
+ hostname = hostname.split(".")[0]
29
+
30
+ assert hostname in data_root_lib.keys(), \
31
+ "Hostname {} not in data_root_lib".format(hostname)
32
+
33
+ data_root = data_root_lib[hostname]
34
+ return data_root
35
+
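A usage sketch (not part of this commit); get_data_root_from_hostname asserts that the short hostname is one of the keys above, so it raises an AssertionError on unlisted machines:

from shared.utils.paths import repo_path, get_data_root_from_hostname

print(repo_path)                            # root of this repository
data_root = get_data_root_from_hostname()  # e.g. a /scratch/... path on listed hosts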
shared/utils/physics.py ADDED
@@ -0,0 +1,341 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
+ # Universal constants
6
+ C = 340. * 100. # Speed of sound in air (cm/s)
7
+
8
+
9
+ def compute_length_of_air_column_cylindrical(
10
+ timestamps, duration, height, b, **kwargs,
11
+ ):
12
+ """
13
+ Randomly chooses an l(t) curve satisfying the two endpoint conditions.
14
+ """
15
+ L = height * ( (1 - np.exp(b * (duration - timestamps))) / (1 - np.exp(b * duration)) )
16
+ return L
17
+
18
+
19
+ def compute_axial_frequency_cylindrical(
20
+ lengths, radius, beta=0.62, mode=1, **kwargs,
21
+ ):
22
+ """
23
+ Computes axial resonance frequency for cylindrical container at given timestamps.
24
+ """
25
+ if mode == 1:
26
+ harmonic_weight = 1.
27
+ elif mode == 2:
28
+ harmonic_weight = 3.
29
+ elif mode == 3:
30
+ harmonic_weight = 5.
31
+ else:
32
+ raise ValueError
33
+
34
+ # Compute fundamental frequency curve
35
+ F0 = harmonic_weight * (0.25 * C) * (1. / (lengths + (beta * radius)))
36
+
37
+ return F0
38
+
39
+
40
+ def compute_axial_frequency_bottleneck(
41
+ lengths, radius, height, Rn, Hn, beta_bottle=(0.6 + 8/np.pi), **kwargs,
42
+ ):
43
+ # Here, R and H are base radius and height of the bottleneck
44
+ eps = 1e-6
45
+ kappa = (0.5 * C / np.pi) * (Rn/radius) * np.sqrt(1 / (Hn + beta_bottle * Rn))
46
+ frequencies = kappa * np.sqrt(1 / (lengths + eps))
47
+ return frequencies
48
+
49
+
50
+ def compute_f0_cylindrical(Y, rho_g, a, R, H, mode=1, **kwargs,):
51
+
52
+ if mode == 1:
53
+ m = 1.875
54
+ n = 2
55
+ elif mode == 2:
56
+ m = 4.694
57
+ n = 3
58
+ elif mode == 3:
59
+ m = 7.855
60
+ n = 4
61
+ else:
62
+ raise ValueError
63
+
64
+ term = ( ((n**2 - 1)**2) + ((m * R/H)**4) ) / (1 + (1./n**2))
65
+ f0 = (1. / (12 * np.pi)) * np.sqrt(3 * Y / rho_g) * (a / (R**2)) * np.sqrt(term)
66
+ return f0
67
+
68
+
69
+ def compute_xi_cylindrical(rho_l, rho_g, R, a, **kwargs,):
70
+ """
71
+ Different papers use different multipliers.
72
+ For us, using 12. * (4./9.) works best empirically.
73
+ """
74
+ xi = 12. * (4. / 9.) * (rho_l/rho_g) * (R/a)
75
+ return xi
76
+
77
+
78
+ def compute_radial_frequency_cylindrical(
79
+ heights, R, H, Y, rho_g, a, rho_l, power=3, mode=1, **kwargs,
80
+ ):
81
+ """
82
+ Computes radial resonance frequency for cylindrical.
83
+
84
+ Args:
85
+ heights (np.ndarray): height of liquid at pre-defined time stamps
86
+ """
87
+ # Only f0 changes for higher modes
88
+ f0 = compute_f0_cylindrical(Y, rho_g, a, R, H, mode=mode)
89
+ xi = compute_xi_cylindrical(rho_l, rho_g, R, a)
90
+ frequencies = f0 / np.sqrt(1 + xi * ((heights/H) ** power) )
91
+ return frequencies
92
+
93
+
94
+ def compute_slant_lengths_semiconical(
95
+ timestamps, duration, r_top, r_bot, height, **kwargs,
96
+ ):
97
+
98
+ # Top radius / base radius
99
+ rf = r_bot / r_top
100
+
101
+ # Time fraction
102
+ tf = timestamps/duration
103
+
104
+ # Height fractions: h(t) / H
105
+ height_fractions = (1. / (rf - 1)) * (np.cbrt(((rf**3 - 1) * (tf)) + 1) - 1)
106
+
107
+ # Slant air column lengths
108
+ heights = height_fractions * height
109
+ slant_lengths = np.sqrt(1 - ((r_top - r_bot) / height)**2) * (height - heights)
110
+
111
+ return slant_lengths
112
+
113
+
114
+ def compute_axial_frequency_semiconical(slant_lengths, r_top, r_bot, beta=1.28, **kwargs):
115
+ """
116
+ Computes axial resonance frequency for a semiconical container.
117
+
118
+ Args:
119
+ slant_lengths (np.ndarray): slant length of air column
120
+ r_top (float): top radius
121
+ r_bot (float): base radius
122
+ beta (float): end correction coefficient
123
+ """
124
+ frequencies_axial = (C / 2) * (1 / (slant_lengths + (beta * (r_bot + r_top))))
125
+ return frequencies_axial
126
+
127
+
128
+ def get_frequencies(
129
+ t,
130
+ params,
131
+ container_shape="cylindrical",
132
+ harmonic=None,
133
+ vibration_type="axial",
134
+ semiconical_as_cylinder=False,
135
+ ):
136
+ """
137
+ Computes the required frequency f(t) at the given timestamps t.
138
+ """
139
+
140
+ if container_shape == "semiconical":
141
+ # Makes an assumption that semiconical shape is similar to cylindrical
142
+ if semiconical_as_cylinder:
143
+ container_shape = "cylindrical"
144
+
145
+ if (container_shape == "cylindrical") or (container_shape == "bottleneck_as_cylindrical"):
146
+
147
+ # Compute length of air column first
148
+ lengths = compute_length_of_air_column_cylindrical(t, **params)
149
+
150
+ if vibration_type == "axial":
151
+ frequencies = compute_axial_frequency_cylindrical(lengths, **params)
152
+
153
+ if harmonic is not None:
154
+ assert harmonic > 0 and isinstance(harmonic, int)
155
+ frequencies = frequencies * harmonic
156
+
157
+ elif vibration_type == "radial":
158
+ if harmonic is None:
159
+ mode = 1
160
+ else:
161
+ assert isinstance(harmonic, int)
162
+ assert harmonic in [1, 2]
163
+ mode = harmonic + 1
164
+ frequencies = compute_radial_frequency_cylindrical(
165
+ lengths, mode=mode, **params,
166
+ )
167
+
168
+ else:
169
+ raise NotImplementedError
170
+
171
+ elif container_shape == "semiconical":
172
+
173
+ # Compute length of air column first
174
+ slant_lengths = compute_slant_lengths_semiconical(t, **params)
175
+
176
+ if vibration_type == "axial":
177
+ frequencies = compute_axial_frequency_semiconical(
178
+ slant_lengths, **params,
179
+ )
180
+
181
+ if harmonic is not None:
182
+ assert harmonic > 0 and isinstance(harmonic, int)
183
+ frequencies = frequencies * harmonic
184
+
185
+ else:
186
+ raise NotImplementedError
187
+
188
+ elif container_shape == "bottleneck":
189
+
190
+ # Compute length of air column first assuming
191
+ # base of the bottle is a cylindrical
192
+ lengths = compute_length_of_air_column_cylindrical(t, **params)
193
+
194
+ if vibration_type == "axial":
195
+ frequencies = compute_axial_frequency_bottleneck(
196
+ lengths, **params,
197
+ )
198
+
199
+ if harmonic is not None:
200
+ assert harmonic > 0 and isinstance(harmonic, int)
201
+ frequencies = frequencies * harmonic
202
+ else:
203
+ raise NotImplementedError
204
+
205
+ else:
206
+ raise ValueError
207
+
208
+ return frequencies
209
+
210
+
211
+ def get_params(row, semiconical_as_cylinder=False):
212
+ m = row["measurements"]
213
+ duration = row["end_time"] - row["start_time"]
214
+ params = dict(duration=duration)
215
+ if row["shape"] == "cylindrical":
216
+ radius = 0.25 * (m["diameter_top"] + m["diameter_bottom"])
217
+ height = m["net_height"]
218
+ params.update(
219
+ height=height,
220
+ radius=radius,
221
+ beta=row.get("beta", 0.62),
222
+ # Constant flow
223
+ b=0.01,
224
+ )
225
+ elif row["shape"] == "semiconical":
226
+
227
+ if semiconical_as_cylinder:
228
+ # Assume semiconical shape as cylindrical
229
+ radius = 0.25 * (m["diameter_top"] + m["diameter_bottom"])
230
+ height = m["net_height"]
231
+ params.update(
232
+ height=height,
233
+ radius=radius,
234
+ beta=0.62,
235
+ # Constant flow
236
+ b=0.01,
237
+ )
238
+ else:
239
+ r_top = 0.5 * m["diameter_top"]
240
+ r_bot = 0.5 * m["diameter_bottom"]
241
+ height = m["net_height"]
242
+ beta = 1.28
243
+ params.update(
244
+ r_top=r_top,
245
+ r_bot=r_bot,
246
+ height=height,
247
+ beta=beta,
248
+ )
249
+ elif row["shape"] == "bottleneck":
250
+ radius = 0.5 * m["diameter_bottom"]
251
+ Rn = 0.5 * m["diameter_top"]
252
+ Hn = m["neck_height"]
253
+ height = m["net_height"] - Hn
254
+ params.update(
255
+ height=height,
256
+ radius=radius,
257
+ Rn=Rn,
258
+ Hn=Hn,
259
+ # Constant flow
260
+ b=0.01,
261
+ )
262
+ elif row["shape"] == "bottleneck_as_cylindrical":
263
+ # Approximates bottleneck as cylindrical
264
+ radius = 0.5 * m["diameter_bottom"]
265
+ height = m["net_height"] + m["neck_height"]
266
+ params.update(
267
+ height=height,
268
+ radius=radius,
269
+ beta=row.get("beta", 0.62),
270
+ # Constant flow
271
+ b=0.01,
272
+ )
273
+ else:
274
+ raise ValueError
275
+ return params
276
+
277
+ def frequency_to_wavelength(f):
278
+ """
279
+ Converts frequency to wavelength.
280
+
281
+ Args:
282
+ f (float): frequency
283
+ """
284
+ return C / f
285
+
286
+
287
+ def wavelength_to_frequency(l):
288
+ """
289
+ Converts wavelength to frequency.
290
+
291
+ Args:
292
+ l (float): wavelength
293
+ """
294
+ return C / l
295
+
296
+
297
+ def get_cylinder_radius(m):
298
+ return 0.25 * (m['diameter_top'] + m['diameter_bottom'])
299
+
300
+
301
+ def get_cylinder_height(m):
302
+ return m['net_height']
303
+
304
+
305
+ def get_flow_rate(m, duration):
306
+ r = get_cylinder_radius(m)
307
+ h = get_cylinder_height(m)
308
+ volume = np.pi * (r**2) * h
309
+ q = volume / duration
310
+ return q
311
+
312
+
313
+ def get_length_of_air_column(m, duration, timestamps):
314
+ h = get_cylinder_height(m)
315
+ l = (-h/duration) * timestamps + h
316
+ l = torch.from_numpy(l)
317
+ return l
318
+
319
+
320
+ def estimate_cylinder_radius(wavelengths, timestamps=None, beta=0.62):
321
+ radius_pred = ((1. / beta) * (wavelengths[-1] / 4.)).item()
322
+ return radius_pred
323
+
324
+
325
+ def estimate_cylinder_height(wavelengths, timestamps=None, beta=0.62):
326
+ height_pred = wavelengths[0] / 4. - wavelengths[-1] / 4.
327
+ return height_pred.item()
328
+
329
+
330
+ def estimate_flow_rate(wavelengths, timestamps=None, output_fps=49.):
331
+ radius = estimate_cylinder_radius(wavelengths)
332
+ l_pred = (wavelengths - wavelengths[-1]) / 4.
333
+ slope = np.gradient(l_pred).mean() * output_fps
334
+ Q_pred = -np.pi * (radius**2) * slope
335
+ return Q_pred
336
+
337
+
338
+ def estimate_length_of_air_column(wavelengths, timestamps=None):
339
+ l_pred = (wavelengths - wavelengths[-1]) / 4.
340
+ return l_pred
341
+
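A worked sketch (not part of this commit) of the cylindrical axial model above, using the module constant C = 340 * 100 cm/s and the default end correction beta = 0.62; the container dimensions are made-up illustrative values:

import numpy as np
from shared.utils.physics import (
    compute_length_of_air_column_cylindrical,
    compute_axial_frequency_cylindrical,
)

t = np.linspace(0, 10.0, 6)  # timestamps over a 10 s pour
lengths = compute_length_of_air_column_cylindrical(
    t, duration=10.0, height=15.0, b=0.01,  # 15 cm tall container, near-constant flow
)
freqs = compute_axial_frequency_cylindrical(lengths, radius=3.0)
# e.g. with a 10 cm air column and a 3 cm radius:
# f = 0.25 * 34000 / (10 + 0.62 * 3) ≈ 717 Hz, and it rises as the column shortens.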
shared/utils/text_basic.py ADDED
@@ -0,0 +1,44 @@
1
+ """Utils for processing and encoding text."""
2
+
3
+ import torch
4
+
5
+
6
+
7
+ def lemmatize_verbs(verbs: list):
8
+ from nltk.stem import WordNetLemmatizer
9
+ wnl = WordNetLemmatizer()
10
+ return [wnl.lemmatize(verb, 'v') for verb in verbs]
11
+
12
+
13
+ def lemmatize_adverbs(adverbs: list):
14
+ from nltk.stem import WordNetLemmatizer
15
+ wnl = WordNetLemmatizer()
16
+ return [wnl.lemmatize(adverb, 'r') for adverb in adverbs]
17
+
18
+
19
+ class SentenceEncoder:
20
+
21
+ def __init__(self, model_name="roberta-base"):
22
+ from transformers import RobertaTokenizer, RobertaModel
23
+ if model_name == 'roberta-base':
24
+ self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
25
+ self.model = RobertaModel.from_pretrained(model_name)
26
+
27
+ def encode_sentence(self, sentence):
28
+ inputs = self.tokenizer.encode_plus(
29
+ sentence, add_special_tokens=True, return_tensors='pt',
30
+ )
31
+ with torch.no_grad():
32
+ outputs = self.model(**inputs)
33
+ # sentence_embedding = torch.mean(outputs.last_hidden_state, dim=1).squeeze(0)
34
+ sentence_embedding = outputs.last_hidden_state[:, 0, :]
35
+ return sentence_embedding
36
+
37
+ def encode_sentences(self, sentences):
38
+ """Encodes a list of sentences using model."""
39
+ tokenized_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
40
+ with torch.no_grad():
41
+ outputs = self.model(**tokenized_input)
42
+ embeddings = outputs.last_hidden_state[:, 0, :]
43
+ return embeddings
44
+
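A usage sketch (not part of this commit) for SentenceEncoder, assuming transformers is installed and the roberta-base weights are cached locally or can be downloaded:

from shared.utils.text_basic import SentenceEncoder

encoder = SentenceEncoder("roberta-base")
emb = encoder.encode_sentence("Water pouring into a cylindrical glass.")
print(emb.shape)   # torch.Size([1, 768]) -- the [CLS] token embedding

embs = encoder.encode_sentences(["first sentence", "second sentence"])
print(embs.shape)  # torch.Size([2, 768])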
shared/utils/visualize.py ADDED
@@ -0,0 +1,2208 @@
1
+ """Helpers for visualization"""
2
+ import os
3
+ import numpy as np
4
+ import matplotlib
5
+ import matplotlib.pyplot as plt
6
+ import cv2
7
+ import PIL
8
+ from PIL import Image, ImageOps, ImageDraw
9
+ from os.path import exists
10
+ import librosa.display
11
+ import pandas as pd
12
+ import itertools
13
+ import librosa
14
+ from tqdm import tqdm
15
+ from IPython.display import Audio, Markdown, display
16
+ from ipywidgets import Button, HBox, VBox, Text, Label, HTML, widgets
17
+ from shared.utils.log import tqdm_iterator
18
+
19
+ import warnings
20
+ warnings.filterwarnings("ignore")
21
+
22
+ try:
23
+ import torchvideotransforms
24
+ except:
25
+ print("Failed to import torchvideotransforms. Proceeding without.")
26
+ print("Please install using:")
27
+ print("pip install git+https://github.com/hassony2/torch_videovision")
28
+
29
+
30
+ # define predominant colors
31
+ COLORS = {
32
+ "pink": (242, 116, 223),
33
+ "cyan": (46, 242, 203),
34
+ "red": (255, 0, 0),
35
+ "green": (0, 255, 0),
36
+ "blue": (0, 0, 255),
37
+ "yellow": (255, 255, 0),
38
+ }
39
+
40
+
41
+ def get_predominant_color(color_key, mode="RGB", alpha=0):
42
+ assert color_key in COLORS.keys(), f"Unknown color key: {color_key}"
43
+ if mode == "RGB":
44
+ return COLORS[color_key]
45
+ elif mode == "RGBA":
46
+ return COLORS[color_key] + (alpha,)
47
+
48
+
49
+ def show_single_image(image: np.ndarray, figsize: tuple = (8, 8), title: str = None, cmap: str = None, ticks=False):
50
+ """Show a single image."""
51
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
52
+
53
+ if isinstance(image, Image.Image):
54
+ image = np.asarray(image)
55
+
56
+ ax.set_title(title)
57
+ ax.imshow(image, cmap=cmap)
58
+
59
+ if not ticks:
60
+ ax.set_xticks([])
61
+ ax.set_yticks([])
62
+
63
+ plt.show()
64
+
65
+
66
+ def show_grid_of_images(
67
+ images: np.ndarray, n_cols: int = 4, figsize: tuple = (8, 8), subtitlesize=14,
68
+ cmap=None, subtitles=None, title=None, save=False, savepath="sample.png", titlesize=20,
69
+ ysuptitle=0.8, xlabels=None, sizealpha=0.7, show=True, row_labels=None, aspect=None,
70
+ ):
71
+ """Show a grid of images."""
72
+ n_cols = min(n_cols, len(images))
73
+
74
+ copy_of_images = images.copy()
75
+ for i, image in enumerate(copy_of_images):
76
+ if isinstance(image, Image.Image):
77
+ image = np.asarray(image)
78
+ copy_of_images[i] = image
79
+
80
+ if subtitles is None:
81
+ subtitles = [None] * len(images)
82
+
83
+ if xlabels is None:
84
+ xlabels = [None] * len(images)
85
+
86
+ if row_labels is None:
87
+ num_rows = int(np.ceil(len(images) / n_cols))
88
+ row_labels = [None] * num_rows
89
+
90
+ n_rows = int(np.ceil(len(images) / n_cols))
91
+ fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
92
+ if len(images) == 1:
93
+ axes = np.array([[axes]])
94
+ for i, ax in enumerate(axes.flat):
95
+ if i < len(copy_of_images):
96
+ if len(copy_of_images[i].shape) == 2 and cmap is None:
97
+ cmap="gray"
98
+ ax.imshow(copy_of_images[i], cmap=cmap, aspect=aspect)
99
+ ax.set_title(subtitles[i], fontsize=subtitlesize)
100
+ ax.set_xlabel(xlabels[i], fontsize=sizealpha * subtitlesize)
101
+ ax.set_xticks([])
102
+ ax.set_yticks([])
103
+
104
+ col_idx = i % n_cols
105
+ if col_idx == 0:
106
+ ax.set_ylabel(row_labels[i // n_cols], fontsize=sizealpha * subtitlesize)
107
+
108
+ fig.tight_layout()
109
+ plt.suptitle(title, y=ysuptitle, fontsize=titlesize)
110
+ if save:
111
+ plt.savefig(savepath, bbox_inches='tight')
112
+ if show:
113
+ plt.show()
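A usage sketch (not part of this commit) for show_grid_of_images, assuming the module's display dependencies (matplotlib, opencv, librosa, ipywidgets) are installed; random frames stand in for real video frames:

import numpy as np
from shared.utils.visualize import show_grid_of_images

frames = [np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) for _ in range(6)]
show_grid_of_images(frames, n_cols=3, subtitles=[f"frame {i}" for i in range(6)])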
114
+
115
+
116
+
117
+ def add_text_to_image(image, text):
118
+ from PIL import ImageFont
119
+ from PIL import ImageDraw
120
+
121
+ # # resize image
122
+ # image = image.resize((image.size[0] * 2, image.size[1] * 2))
123
+
124
+ draw = ImageDraw.Draw(image)
125
+ font = ImageFont.load_default()
126
+ # font = ImageFont.load("arial.pil")
127
+ # font = ImageFont.FreeTypeFont(size=20)
128
+ # font = ImageFont.truetype("arial.ttf", 28, encoding="unic")
129
+
130
+ # change fontsize
131
+
132
+ # select color = black if image is mostly white
133
+ if np.mean(image) > 200:
134
+ draw.text((0, 0), text, (0,0,0), font=font)
135
+ else:
136
+ draw.text((0, 0), text, (255,255,255), font=font)
137
+
138
+ # draw.text((0, 0), text, (255,255,255), font=font)
139
+ return image
140
+
141
+
142
+ def show_keypoint_matches(
143
+ img1, kp1, img2, kp2, matches,
144
+ K=10, figsize=(10, 5), drawMatches_args=dict(matchesThickness=3, singlePointColor=(0, 0, 0)),
145
+ choose_matches="random",
146
+ ):
147
+ """Displays matches found in the pair of images"""
148
+ if choose_matches == "random":
149
+ selected_matches = np.random.choice(matches, K)
150
+ elif choose_matches == "all":
151
+ K = len(matches)
152
+ selected_matches = matches
153
+ elif choose_matches == "topk":
154
+ selected_matches = matches[:K]
155
+ else:
156
+ raise ValueError(f"Unknown value for choose_matches: {choose_matches}")
157
+
158
+ # color each match with a different color
159
+ cmap = matplotlib.cm.get_cmap('gist_rainbow', K)
160
+ colors = [[int(x*255) for x in cmap(i)[:3]] for i in np.arange(0,K)]
161
+ drawMatches_args.update({"matchColor": -1, "singlePointColor": (100, 100, 100)})
162
+
163
+ img3 = cv2.drawMatches(img1, kp1, img2, kp2, selected_matches, outImg=None, **drawMatches_args)
164
+ show_single_image(
165
+ img3,
166
+ figsize=figsize,
167
+ title=f"[{choose_matches.upper()}] Selected K = {K} matches between the pair of images.",
168
+ )
169
+ return img3
170
+
171
+
172
+ def draw_kps_on_image(image: np.ndarray, kps: np.ndarray, color=COLORS["red"], radius=3, thickness=-1, return_as="PIL"):
173
+ """
174
+ Draw keypoints on image.
175
+
176
+ Args:
177
+ image: Image to draw keypoints on.
178
+ kps: Keypoints to draw. Note these should be in (x, y) format.
179
+ """
180
+ if isinstance(image, Image.Image):
181
+ image = np.asarray(image)
182
+ if isinstance(color, str):
183
+ color = PIL.ImageColor.getrgb(color)
184
+ colors = [color] * len(kps)
185
+ elif isinstance(color, tuple):
186
+ colors = [color] * len(kps)
187
+ elif isinstance(color, list):
188
+ colors = [PIL.ImageColor.getrgb(c) for c in color]
189
+ assert len(colors) == len(kps), f"Number of colors ({len(colors)}) must be equal to number of keypoints ({len(kps)})"
190
+
191
+ for kp, c in zip(kps, colors):
192
+ image = cv2.circle(
193
+ image.copy(), (int(kp[0]), int(kp[1])), radius=radius, color=c, thickness=thickness)
194
+
195
+ if return_as == "PIL":
196
+ return Image.fromarray(image)
197
+
198
+ return image
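A usage sketch (not part of this commit) for draw_kps_on_image; keypoints are (x, y) pixel coordinates and the values below are purely illustrative:

import numpy as np
from shared.utils.visualize import draw_kps_on_image

image = np.zeros((128, 128, 3), dtype=np.uint8)
kps = np.array([[32, 32], [64, 96]])
out = draw_kps_on_image(image, kps, color="red", radius=5)  # returns a PIL.Image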
199
+
200
+
201
+ def get_concat_h(im1, im2):
202
+ """Concatenate two images horizontally"""
203
+ dst = Image.new('RGB', (im1.width + im2.width, im1.height))
204
+ dst.paste(im1, (0, 0))
205
+ dst.paste(im2, (im1.width, 0))
206
+ return dst
207
+
208
+
209
+ def get_concat_v(im1, im2):
210
+ """Concatenate two images vertically"""
211
+ dst = Image.new('RGB', (im1.width, im1.height + im2.height))
212
+ dst.paste(im1, (0, 0))
213
+ dst.paste(im2, (0, im1.height))
214
+ return dst
215
+
216
+
217
+ def show_images_with_keypoints(images: list, kps: list, radius=15, color=(0, 220, 220), figsize=(10, 8)):
218
+ assert len(images) == len(kps)
219
+
220
+ # generate
221
+ images_with_kps = []
222
+ for i in range(len(images)):
223
+ img_with_kps = draw_kps_on_image(images[i], kps[i], radius=radius, color=color, return_as="PIL")
224
+ images_with_kps.append(img_with_kps)
225
+
226
+ # show
227
+ show_grid_of_images(images_with_kps, n_cols=len(images), figsize=figsize)
228
+
229
+
230
+ def set_latex_fonts(usetex=True, fontsize=14, show_sample=False, **kwargs):
231
+ try:
232
+ plt.rcParams.update({
233
+ "text.usetex": usetex,
234
+ "font.family": "serif",
235
+ # "font.serif": ["Computer Modern Romans"],
236
+ "font.size": fontsize,
237
+ **kwargs,
238
+ })
239
+ if show_sample:
240
+ plt.figure()
241
+ plt.title("Sample $y = x^2$")
242
+ plt.plot(np.arange(0, 10), np.arange(0, 10)**2, "--o")
243
+ plt.grid()
244
+ plt.show()
245
+ except:
246
+ print("Failed to setup LaTeX fonts. Proceeding without.")
247
+ pass
248
+
249
+
250
+
251
+ def plot_2d_points(
252
+ list_of_points_2d,
253
+ colors=None,
254
+ sizes=None,
255
+ markers=None,
256
+ alpha=0.75,
257
+ h=256,
258
+ w=256,
259
+ ax=None,
260
+ save=True,
261
+ savepath="test.png",
262
+ ):
263
+
264
+ if ax is None:
265
+ fig, ax = plt.subplots(1, 1)
266
+ ax.set_xlim([0, w])
267
+ ax.set_ylim([0, h])
268
+
269
+ if sizes is None:
270
+ sizes = [0.1 for _ in range(len(list_of_points_2d))]
271
+ if colors is None:
272
+ colors = ["gray" for _ in range(len(list_of_points_2d))]
273
+ if markers is None:
274
+ markers = ["o" for _ in range(len(list_of_points_2d))]
275
+
276
+ for points_2d, color, s, m in zip(list_of_points_2d, colors, sizes, markers):
277
+ ax.scatter(points_2d[:, 0], points_2d[:, 1], s=s, alpha=alpha, color=color, marker=m)
278
+
279
+ if save:
280
+ plt.savefig(savepath, bbox_inches='tight')
281
+
282
+
283
+ def plot_2d_points_on_image(
284
+ image,
285
+ img_alpha=1.0,
286
+ ax=None,
287
+ list_of_points_2d=[],
288
+ scatter_args=dict(),
289
+ ):
290
+ if ax is None:
291
+ fig, ax = plt.subplots(1, 1)
292
+ ax.imshow(image, alpha=img_alpha)
293
+ scatter_args["save"] = False
294
+ plot_2d_points(list_of_points_2d, ax=ax, **scatter_args)
295
+
296
+ # invert the axis
297
+ ax.set_ylim(ax.get_ylim()[::-1])
298
+
299
+
300
+ def compare_landmarks(
301
+ image, ground_truth_landmarks, v2d, predicted_landmarks,
302
+ save=False, savepath="compare_landmarks.png", num_kps_to_show=-1,
303
+ show_matches=True,
304
+ ):
305
+
306
+ # show GT landmarks on image
307
+ fig, axes = plt.subplots(1, 3, figsize=(11, 4))
308
+ ax = axes[0]
309
+ plot_2d_points_on_image(
310
+ image,
311
+ list_of_points_2d=[ground_truth_landmarks],
312
+ scatter_args=dict(sizes=[15], colors=["limegreen"]),
313
+ ax=ax,
314
+ )
315
+ ax.set_title("GT landmarks", fontsize=12)
316
+
317
+ # since the projected points are inverted, using 180 degree rotation about z-axis
318
+ ax = axes[1]
319
+ plot_2d_points_on_image(
320
+ image,
321
+ list_of_points_2d=[v2d, predicted_landmarks],
322
+ scatter_args=dict(sizes=[0.08, 15], markers=["o", "x"], colors=["royalblue", "red"]),
323
+ ax=ax,
324
+ )
325
+ ax.set_title("Projection of predicted mesh", fontsize=12)
326
+
327
+ # plot the ground truth and predicted landmarks on the same image
328
+ ax = axes[2]
329
+ plot_2d_points_on_image(
330
+ image,
331
+ list_of_points_2d=[
332
+ ground_truth_landmarks[:num_kps_to_show],
333
+ predicted_landmarks[:num_kps_to_show],
334
+ ],
335
+ scatter_args=dict(sizes=[15, 15], markers=["o", "x"], colors=["limegreen", "red"]),
336
+ ax=ax,
337
+ img_alpha=0.5,
338
+ )
339
+ ax.set_title("GT and predicted landmarks", fontsize=12)
340
+
341
+ if show_matches:
342
+ for i in range(num_kps_to_show):
343
+ x_values = [ground_truth_landmarks[i, 0], predicted_landmarks[i, 0]]
344
+ y_values = [ground_truth_landmarks[i, 1], predicted_landmarks[i, 1]]
345
+ ax.plot(x_values, y_values, color="yellow", markersize=1, linewidth=2.)
346
+
347
+ fig.tight_layout()
348
+ if save:
349
+ plt.savefig(savepath, bbox_inches="tight")
350
+
351
+
352
+
353
+ def plot_historgam_values(
354
+ X, display_vals=False,
355
+ bins=50, figsize=(8, 5),
356
+ show_mean=True,
357
+ xlabel=None, ylabel=None,
358
+ ax=None, title=None, show=False,
359
+ **kwargs,
360
+ ):
361
+ if ax is None:
362
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
363
+
364
+ ax.hist(X, bins=bins, **kwargs)
365
+ if title is None:
366
+ title = "Histogram of values"
367
+
368
+ ax.set_xlabel(xlabel)
369
+ ax.set_ylabel(ylabel)
370
+
371
+ if display_vals:
372
+ x, counts = np.unique(X, return_counts=True)
373
+ # sort_indices = np.argsort(x)
374
+ # x = x[sort_indices]
375
+ # counts = counts[sort_indices]
376
+ # for i in range(len(x)):
377
+ # ax.text(x[i], counts[i], counts[i], ha='center', va='bottom')
378
+
379
+ ax.grid(alpha=0.3)
380
+
381
+ if show_mean:
382
+ mean = np.mean(X)
383
+ mean_string = rf"$\mu$: {mean:.2f}"
384
+ ax.set_title(title + f" ({mean_string}) ")
385
+ else:
386
+ ax.set_title(title)
387
+
388
+ if not show:
389
+ return ax
390
+ else:
391
+ plt.show()
392
+
393
+
394
+ """Helper functions for all kinds of 2D/3D visualization"""
395
+ def bokeh_2d_scatter(x, y, desc, figsize=(700, 700), colors=None, use_nb=False, title="Bokeh scatter plot"):
396
+ import matplotlib.colors as mcolors
397
+ from bokeh.plotting import figure, output_file, show, ColumnDataSource
398
+ from bokeh.models import HoverTool
399
+ from bokeh.io import output_notebook
400
+
401
+ if use_nb:
402
+ output_notebook()
403
+
404
+ # define colors to be assigned
405
+ if colors is None:
406
+ # applies the same color
407
+ # create a color iterator: pick a random color and apply it to all points
408
+ # colors = [np.random.choice(itertools.cycle(palette))] * len(x)
409
+ colors = [np.random.choice(["red", "green", "blue", "yellow", "pink", "black", "gray"])] * len(x)
410
+
411
+ # # applies different colors
412
+ # colors = np.array([ [r, g, 150] for r, g in zip(50 + 2*x, 30 + 2*y) ], dtype="uint8")
413
+
414
+
415
+ # define the df of data to plot
416
+ source = ColumnDataSource(
417
+ data=dict(
418
+ x=x,
419
+ y=y,
420
+ desc=desc,
421
+ color=colors,
422
+ )
423
+ )
424
+
425
+ # define the attributes to show on hover
426
+ hover = HoverTool(
427
+ tooltips=[
428
+ ("index", "$index"),
429
+ ("(x, y)", "($x, $y)"),
430
+ ("Desc", "@desc"),
431
+ ]
432
+ )
433
+
434
+ p = figure(
435
+ plot_width=figsize[0], plot_height=figsize[1], tools=[hover], title=title,
436
+ )
437
+ p.circle('x', 'y', size=10, source=source, fill_color="color")
438
+ show(p)
439
+
440
+
441
+
442
+
443
+ def bokeh_2d_scatter_new(
444
+ df, x, y, hue, label, color_column=None, size_col=None,
445
+ figsize=(700, 700), use_nb=False, title="Bokeh scatter plot",
446
+ legend_loc="bottom_left", edge_color="black", audio_col=None,
447
+ ):
448
+ from bokeh.plotting import figure, output_file, show, ColumnDataSource
449
+ from bokeh.models import HoverTool
450
+ from bokeh.io import output_notebook
451
+
452
+ if use_nb:
453
+ output_notebook()
454
+
455
+ assert {x, y, hue, label}.issubset(set(df.keys()))
456
+
457
+ if isinstance(color_column, str) and color_column in df.keys():
458
+ color_column_name = color_column
459
+ else:
460
+ import matplotlib.colors as mcolors
461
+ colors = list(mcolors.BASE_COLORS.keys()) + list(mcolors.TABLEAU_COLORS.values())
462
+ # colors = list(mcolors.BASE_COLORS.keys())
463
+ colors = itertools.cycle(np.unique(colors))
464
+
465
+ hue_to_color = dict()
466
+ unique_hues = np.unique(df[hue].values)
467
+ for _hue in unique_hues:
468
+ hue_to_color[_hue] = next(colors)
469
+ df["color"] = df[hue].apply(lambda k: hue_to_color[k])
470
+ color_column_name = "color"
471
+
472
+ if size_col is not None:
473
+ assert isinstance(size_col, str) and size_col in df.keys()
474
+ else:
475
+ sizes = [10.] * len(df)
476
+ df["size"] = sizes
477
+ size_col = "size"
478
+
479
+ source = ColumnDataSource(
480
+ dict(
481
+ x = df[x].values,
482
+ y = df[y].values,
483
+ hue = df[hue].values,
484
+ label = df[label].values,
485
+ color = df[color_column_name].values,
486
+ edge_color = [edge_color] * len(df),
487
+ sizes = df[size_col].values,
488
+ )
489
+ )
490
+
491
+ # define the attributes to show on hover
492
+ hover = HoverTool(
493
+ tooltips=[
494
+ ("index", "$index"),
495
+ ("(x, y)", "($x, $y)"),
496
+ ("Desc", "@label"),
497
+ ("Cluster", "@hue"),
498
+ ]
499
+ )
500
+
501
+ p = figure(
502
+ plot_width=figsize[0],
503
+ plot_height=figsize[1],
504
+ tools=["pan","wheel_zoom","box_zoom","save","reset","help"] + [hover],
505
+ title=title,
506
+ )
507
+ p.circle(
508
+ 'x', 'y', size="sizes",
509
+ source=source, fill_color="color",
510
+ legend_group="hue", line_color="edge_color",
511
+ )
512
+ p.legend.location = legend_loc
513
+ p.legend.click_policy="hide"
514
+
515
+
516
+ show(p)
517
+
518
+
519
+ import torch
520
+ def get_sentence_embedding(model, tokenizer, sentence):
521
+ encoded = tokenizer.encode_plus(sentence, return_tensors="pt")
522
+
523
+ with torch.no_grad():
524
+ output = model(**encoded)
525
+
526
+ last_hidden_state = output.last_hidden_state
527
+ assert last_hidden_state.shape[0] == 1
528
+ assert last_hidden_state.shape[-1] == 768
529
+
530
+ # only pick the [CLS] token embedding (sentence embedding)
531
+ sentence_embedding = last_hidden_state[0, 0]
532
+
533
+ return sentence_embedding
534
+
535
+
536
+ def lighten_color(color, amount=0.5):
537
+ """
538
+ Lightens the given color by multiplying (1-luminosity) by the given amount.
539
+ Input can be matplotlib color string, hex string, or RGB tuple.
540
+
541
+ Examples:
542
+ >> lighten_color('g', 0.3)
543
+ >> lighten_color('#F034A3', 0.6)
544
+ >> lighten_color((.3,.55,.1), 0.5)
545
+ """
546
+ import matplotlib.colors as mc
547
+ import colorsys
548
+ try:
549
+ c = mc.cnames[color]
550
+ except:
551
+ c = color
552
+ c = colorsys.rgb_to_hls(*mc.to_rgb(c))
553
+ return colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2])
554
+
555
+
556
+ def plot_histogram(df, col, ax=None, color="blue", title=None, xlabel=None, **kwargs):
557
+ if ax is None:
558
+ fig, ax = plt.subplots(1, 1, figsize=(5, 4))
559
+ ax.grid(alpha=0.3)
560
+ xlabel = col if xlabel is None else xlabel
561
+ ax.set_xlabel(xlabel)
562
+ ax.set_ylabel("Frequency")
563
+ title = f"Historgam of {col}" if title is None else title
564
+ ax.set_title(title)
565
+ label = f"Mean: {np.round(df[col].mean(), 1)}"
566
+ ax.hist(df[col].values, density=False, color=color, edgecolor=lighten_color(color, 0.1), label=label, **kwargs)
567
+ if "bins" in kwargs:
568
+ xticks = list(np.arange(kwargs["bins"])[::5])
569
+ xticks += list(np.linspace(xticks[-1], int(df[col].max()), 5, dtype=int))
570
+ # print(xticks)
571
+ ax.set_xticks(xticks)
572
+ ax.legend()
573
+ plt.show()
574
+
575
+
576
+ def beautify_ax(ax, title=None, titlesize=20, sizealpha=0.7, xlabel=None, ylabel=None):
577
+ labelsize = sizealpha * titlesize
578
+ ax.grid(alpha=0.3)
579
+ ax.set_xlabel(xlabel, fontsize=labelsize)
580
+ ax.set_ylabel(ylabel, fontsize=labelsize)
581
+ ax.set_title(title, fontsize=titlesize)
582
+
583
+
584
+
585
+
586
+ def get_text_features(text: list, model, device, batch_size=16):
587
+ import clip
588
+ text_batches = [text[i:i+batch_size] for i in range(0, len(text), batch_size)]
589
+ text_features = []
590
+ model = model.to(device)
591
+ model = model.eval()
592
+ for batch in tqdm(text_batches, desc="Getting text features", bar_format="{l_bar}{bar:20}{r_bar}"):
593
+ batch = clip.tokenize(batch).to(device)
594
+ with torch.no_grad():
595
+ batch_features = model.encode_text(batch)
596
+ text_features.append(batch_features.cpu().numpy())
597
+ text_features = np.concatenate(text_features, axis=0)
598
+ return text_features
599
+
600
+
601
+ from sklearn.manifold import TSNE
602
+ def reduce_dim(X, perplexity=30, n_iter=1000):
603
+ tsne = TSNE(
604
+ n_components=2,
605
+ perplexity=perplexity,
606
+ n_iter=n_iter,
607
+ init='pca',
608
+ # learning_rate="auto",
609
+ )
610
+ Z = tsne.fit_transform(X)
611
+ return Z
612
+
613
+
614
+ from IPython.display import Video
615
+ def show_video(video_path):
616
+ """Show a video in a Jupyter notebook"""
617
+ assert exists(video_path), f"Video path {video_path} does not exist"
618
+
619
+ # display the video in a Jupyter notebook
620
+ return Video(video_path, embed=True, width=480)
621
+ # Video(video_path, embed=True, width=600, height=400)
622
+ # html_attributes="controls autoplay loop muted"
623
+
624
+
625
+
626
+
627
+ def show_single_audio(filepath=None, data=None, rate=None, start=None, end=None, label="Sample audio"):
628
+
629
+ if filepath is None:
630
+ assert data is not None and rate is not None, "Either filepath or data and rate must be provided"
631
+ args = dict(data=data, rate=rate)
632
+ else:
633
+ assert data is None and rate is None, "Either filepath or data and rate must be provided"
634
+ data, rate = librosa.load(filepath)
635
+ # args = dict(filename=filepath)
636
+ args = dict(data=data, rate=rate)
637
+
638
+ if start is not None and end is not None:
639
+ start = max(int(start * rate), 0)
640
+ end = min(int(end * rate), len(data))
641
+ else:
642
+ start = 0
643
+ end = len(data)
644
+ data = data[start:end]
645
+ args["data"] = data
646
+
647
+ if label is None:
648
+ label = "Sample audio"
649
+
650
+ label = Label(f"{label}")
651
+ out = widgets.Output()
652
+ with out:
653
+ display(Audio(**args))
654
+ vbox = VBox([label, out])
655
+ return vbox
656
+
657
+
658
+ def show_single_audio_with_spectrogram(filepath=None, data=None, rate=None, label="Sample audio", figsize=(6, 2)):
659
+
660
+ if filepath is None:
661
+ assert data is not None and rate is not None, "Either filepath or data and rate must be provided"
662
+ else:
663
+ data, rate = librosa.load(filepath)
664
+
665
+ # Show audio
666
+ vbox = show_single_audio(data=data, rate=rate, label=label)
667
+ # get width of audio widget
668
+ width = vbox.children[1].layout.width
669
+
670
+ # Show spectrogram
671
+ spec_out = widgets.Output()
672
+ D = librosa.stft(data) # STFT of y
673
+ S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
674
+ with spec_out:
675
+ fig, ax = plt.subplots(figsize=figsize)
676
+ img = librosa.display.specshow(
677
+ S_db,
678
+ ax=ax,
679
+ x_axis='time',
680
+ # y_axis='linear',
681
+ )
682
+ # img = widgets.Image.from_file(fig)
683
+ # import ipdb; ipdb.set_trace()
684
+ # img = widgets.Image(img)
685
+ # add image to vbox
686
+ vbox.children += (spec_out,)
687
+ return vbox
688
+
689
+ def show_spectrogram(audio_path=None, data=None, rate=None, figsize=(6, 2), ax=None, show=True):
690
+ if data is None and rate is None:
691
+ # Show spectrogram
692
+ data, rate = librosa.load(audio_path)
693
+ else:
694
+ assert audio_path is None, "Either audio_path or data and rate must be provided"
695
+
696
+ hop_length = 512
697
+ D = librosa.stft(data, n_fft=2048, hop_length=hop_length, win_length=2048) # STFT of y
698
+ S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
699
+
700
+ # Create spectrogram plot widget
701
+ if ax is None:
702
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
703
+ im = ax.imshow(S_db, origin='lower', aspect='auto', cmap='inferno')
704
+
705
+ # Replace xticks with time
706
+ xticks = ax.get_xticks()
707
+ time_in_seconds = librosa.frames_to_time(xticks, sr=rate, hop_length=hop_length)
708
+ ax.set_xticklabels(np.round(time_in_seconds, 1))
709
+ ax.set_xlabel('Time')
710
+ ax.set_yticks([])
711
+ if ax is None:
712
+ plt.close(fig)
713
+
714
+ # Create widget output
715
+ spec_out = widgets.Output()
716
+ with spec_out:
717
+ display(fig)
718
+ return spec_out
719
+
720
+
721
+ def show_single_video_and_spectrogram(
722
+ video_path, audio_path,
723
+ label="Sample video", figsize=(6, 2),
724
+ width=480,
725
+ show_spec_stats=False,
726
+ ):
727
+ # Show video
728
+ vbox = show_single_video(video_path, label=label, width=width)
729
+ # get width of video widget
730
+ width = vbox.children[1].layout.width
731
+
732
+ # Show spectrogram
733
+ data, rate = librosa.load(audio_path)
734
+ hop_length = 512
735
+ D = librosa.stft(data, n_fft=2048, hop_length=hop_length, win_length=2048) # STFT of y
736
+ S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
737
+
738
+ # Create spectrogram plot widget
739
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
740
+ im = ax.imshow(S_db, origin='lower', aspect='auto', cmap='inferno')
741
+
742
+ # Replace xticks with time
743
+ xticks = ax.get_xticks()
744
+ time_in_seconds = librosa.frames_to_time(xticks, sr=rate, hop_length=hop_length)
745
+ ax.set_xticklabels(np.round(time_in_seconds, 1))
746
+ ax.set_xlabel('Time')
747
+ ax.set_yticks([])
748
+ plt.close(fig)
749
+
750
+ # Create widget output
751
+ spec_out = widgets.Output()
752
+ with spec_out:
753
+ display(fig)
754
+ vbox.children += (spec_out,)
755
+
756
+ if show_spec_stats:
757
+ # Compute mean of spectrogram over frequency axis
758
+ eps = 1e-5
759
+ S_db_normalized = (S_db - S_db.mean(axis=1)[:, None]) / (S_db.std(axis=1)[:, None] + eps)
760
+ S_db_over_time = S_db_normalized.sum(axis=0)
761
+ # Plot S_db_over_time
762
+ fig, ax = plt.subplots(1, 1, figsize=(6, 2))
763
+ # ax.set_title("Spectrogram over time")
764
+ ax.grid(alpha=0.5)
765
+ x = np.arange(len(S_db_over_time))
766
+ x = librosa.frames_to_time(x, sr=rate, hop_length=hop_length)
767
+ x = np.round(x, 1)
768
+ ax.plot(x, S_db_over_time)
769
+ ax.set_xlabel('Time')
770
+ ax.set_yticks([])
771
+ plt.close(fig)
772
+ plot_out = widgets.Output()
773
+ with plot_out:
774
+ display(fig)
775
+ vbox.children += (plot_out,)
776
+
777
+ return vbox
778
+
779
+
780
+ def show_single_spectrogram(
781
+ filepath=None,
782
+ data=None,
783
+ rate=None,
784
+ start=None,
785
+ end=None,
786
+ ax=None,
787
+ label="Sample spectrogram",
788
+ figsize=(6, 2),
789
+ xlabel="Time",
790
+ ):
791
+
792
+ if filepath is None:
793
+ assert data is not None and rate is not None, "Either filepath or data and rate must be provided"
794
+ else:
795
+ rate = 22050
796
+ offset = start or 0
797
+ clip_duration = end - start if end is not None else None
798
+ data, rate = librosa.load(filepath, sr=rate, offset=offset, duration=clip_duration)
799
+
800
+ # start = 0 if start is None else int(rate * start)
801
+ # end = len(data) if end is None else int(rate * end)
802
+ # data = data[start:end]
803
+
804
+ # Show spectrogram
805
+ spec_out = widgets.Output()
806
+ D = librosa.stft(data) # STFT of y
807
+ S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
808
+
809
+ if ax is None:
810
+ fig, ax = plt.subplots(figsize=figsize)
811
+
812
+ with spec_out:
813
+ img = librosa.display.specshow(
814
+ S_db,
815
+ ax=ax,
816
+ x_axis='time',
817
+ sr=rate,
818
+ # y_axis='linear',
819
+ )
820
+ ax.set_xlabel(xlabel)
821
+ ax.margins(x=0)
822
+ plt.subplots_adjust(wspace=0, hspace=0)
823
+
824
+ # img = widgets.Image.from_file(fig)
825
+ # import ipdb; ipdb.set_trace()
826
+ # img = widgets.Image(img)
827
+ # add image to vbox
828
+ vbox = VBox([spec_out])
829
+ return vbox
830
+ # return spec_out
831
+
832
+
833
+ # from decord import VideoReader
834
+ def show_single_video(filepath, label="Sample video", width=480, fix_resolution=True):
835
+
836
+ if label is None:
837
+ label = "Sample video"
838
+
839
+ height = None
840
+ if fix_resolution:
841
+ aspect_ratio = 16. / 9.
842
+ height = int(width * (1/ aspect_ratio))
843
+
844
+ label = Label(f"{label}")
845
+ out = widgets.Output()
846
+ with out:
847
+ display(Video(filepath, embed=True, width=width, height=height))
848
+ vbox = VBox([label, out])
849
+ return vbox
850
+
851
+
852
+ def show_grid_of_audio(files, starts=None, ends=None, labels=None, ncols=None, show_spec=False):
853
+
854
+ for f in files:
855
+ assert os.path.exists(f), f"File {f} does not exist."
856
+
857
+ if labels is None:
858
+ labels = [None] * len(files)
859
+
860
+ if starts is None:
861
+ starts = [None] * len(files)
862
+
863
+ if ends is None:
864
+ ends = [None] * len(files)
865
+
866
+ assert len(files) == len(labels)
867
+
868
+ if ncols is None:
869
+ ncols = 3
870
+ nfiles = len(files)
871
+ nrows = nfiles // ncols + (nfiles % ncols != 0)
872
+ # print(nrows, ncols)
873
+
874
+ for i in range(nrows):
875
+ row_hbox = []
876
+ for j in range(ncols):
877
+ idx = i * ncols + j
878
+ # print(i, j, idx)
879
+
880
+ if idx < len(files):
881
+ file, label = files[idx], labels[idx]
882
+ start, end = starts[idx], ends[idx]
883
+ vbox = show_single_audio(
884
+ filepath=file, label=label, start=start, end=end
885
+ )
886
+ if show_spec:
887
+ spec_box = show_spectrogram(file, figsize=(3.6, 1))
888
+ # Add spectrogram to vbox
889
+ vbox.children += (spec_box,)
890
+
891
+ # if not show_spec:
892
+ # vbox = show_single_audio(
893
+ # filepath=file, label=label, start=start, end=end
894
+ # )
895
+ # else:
896
+ # vbox = show_single_audio_with_spectrogram(
897
+ # filepath=file, label=label
898
+ # )
899
+ row_hbox.append(vbox)
900
+ row_hbox = HBox(row_hbox)
901
+ display(row_hbox)
902
+
903
+
904
+ def show_grid_of_videos(
905
+ files,
906
+ cut=False,
907
+ starts=None,
908
+ ends=None,
909
+ labels=None,
910
+ ncols=None,
911
+ width_overflow=False,
912
+ show_spec=False,
913
+ width_of_screen=1000,
914
+ ):
915
+ from moviepy.editor import VideoFileClip
916
+
917
+ for f in files:
918
+ assert os.path.exists(f), f"File {f} does not exist."
919
+
920
+ if labels is None:
921
+ labels = [None] * len(files)
922
+ if starts is not None and ends is not None:
923
+ cut = True
924
+ if starts is None:
925
+ starts = [None] * len(files)
926
+ if ends is None:
927
+ ends = [None] * len(files)
928
+
929
+ assert len(files) == len(labels) == len(starts) == len(ends)
930
+
931
+ # cut the videos to the specified duration
932
+ if cut:
933
+ cut_files = []
934
+ for i, f in enumerate(files):
935
+ start, end = starts[i], ends[i]
936
+
937
+ tmp_f = os.path.join(os.path.expanduser("~"), f"tmp/clip_{i}.mp4")
938
+ cut_files.append(tmp_f)
939
+
940
+ video = VideoFileClip(f)
941
+ start = 0 if start is None else start
942
+ end = video.duration-1 if end is None else end
943
+ # print(start, end)
944
+ video.subclip(start, end).write_videofile(tmp_f, logger=None, verbose=False)
945
+ files = cut_files
946
+
947
+ if ncols is None:
948
+ ncols = 3
949
+ width_of_screen = 1000
950
+
951
+ # get width of the whole display screen
952
+ if not width_overflow:
953
+ width_of_single_video = width_of_screen // ncols
954
+ else:
955
+ width_of_single_video = 280
956
+
957
+ nfiles = len(files)
958
+ nrows = nfiles // ncols + (nfiles % ncols != 0)
959
+ # print(nrows, ncols)
960
+
961
+ for i in range(nrows):
962
+ row_hbox = []
963
+ for j in range(ncols):
964
+ idx = i * ncols + j
965
+ # print(i, j, idx)
966
+
967
+ if idx < len(files):
968
+ file, label = files[idx], labels[idx]
969
+ if not show_spec:
970
+ vbox = show_single_video(file, label, width_of_single_video)
971
+ else:
972
+ vbox = show_single_video_and_spectrogram(file, file, width=width_of_single_video, label=label)
973
+ row_hbox.append(vbox)
974
+ row_hbox = HBox(row_hbox)
975
+ display(row_hbox)
976
+
977
+
978
+
979
+ def preview_video(fp, label="Sample video frames", mode="uniform", frames_to_show=6):
980
+ from decord import VideoReader
981
+
982
+ assert exists(fp), f"Video does not exist at {fp}"
983
+ vr = VideoReader(fp)
984
+
985
+ nfs = len(vr)
986
+ fps = vr.get_avg_fps()
987
+ dur = nfs / fps
988
+
989
+ if mode == "all":
990
+ frame_indices = np.arange(nfs)
991
+ elif mode == "uniform":
992
+ frame_indices = np.linspace(0, nfs - 1, frames_to_show, dtype=int)
993
+ elif mode == "random":
994
+ frame_indices = np.random.randint(0, nfs - 1, replace=False)
995
+ frame_indices = sorted(frame_indices)
996
+ else:
997
+ raise ValueError(f"Unknown frame viewing mode {mode}.")
998
+
999
+ # Show grid of image
1000
+ images = vr.get_batch(frame_indices).asnumpy()
1001
+ show_grid_of_images(images, n_cols=len(frame_indices), title=label, figsize=(12, 2.3), titlesize=10)
1002
+
1003
+
1004
+ def preview_multiple_videos(fps, labels, mode="uniform", frames_to_show=6):
1005
+ for fp in fps:
1006
+ assert exists(fp), f"Video does not exist at {fp}"
1007
+
1008
+ for fp, label in zip(fps, labels):
1009
+ preview_video(fp, label, mode=mode, frames_to_show=frames_to_show)
1010
+
1011
+
1012
+
1013
+ def show_small_clips_in_a_video(
1014
+ video_path,
1015
+ clip_segments: list,
1016
+ width=360,
1017
+ labels=None,
1018
+ show_spec=False,
1019
+ resize=False,
1020
+ ):
1021
+ from moviepy.editor import VideoFileClip
1022
+ from ipywidgets import Layout
1023
+
1024
+ video = VideoFileClip(video_path)
1025
+
1026
+ if resize:
1027
+ # Resize the video
1028
+ print("Resizing the video to width", width)
1029
+ video = video.resize(width=width)
1030
+
1031
+ if labels is None:
1032
+ labels = [
1033
+ f"Clip {i+1} [{clip_segments[i][0]} : {clip_segments[i][1]}]" for i in range(len(clip_segments))
1034
+ ]
1035
+ else:
1036
+ assert len(labels) == len(clip_segments)
1037
+
1038
+ tmp_dir = os.path.join(os.path.expanduser("~"), "tmp")
1039
+ tmp_clippaths = [f"{tmp_dir}/clip_{i}.mp4" for i in range(len(clip_segments))]
1040
+
1041
+ iterator = tqdm_iterator(zip(clip_segments, tmp_clippaths), total=len(clip_segments), desc="Preparing clips")
1042
+ clips = [
1043
+ video.subclip(x, y).write_videofile(f, logger=None, verbose=False) \
1044
+ for (x, y), f in iterator
1045
+ ]
1046
+ # show_grid_of_videos(tmp_clippaths, labels, ncols=len(clips), width_overflow=True)
1047
+ hbox = []
1048
+ for i in range(len(clips)):
1049
+ # vbox = show_single_video(tmp_clippaths[i], labels[i], width=280)
1050
+
1051
+ vbox = widgets.Output()
1052
+ with vbox:
1053
+ if show_spec:
1054
+ display(
1055
+ show_single_video_and_spectrogram(
1056
+ tmp_clippaths[i], tmp_clippaths[i],
1057
+ width=width, figsize=(4.4, 1.5),
1058
+ )
1059
+ )
1060
+ else:
1061
+ display(Video(tmp_clippaths[i], embed=True, width=width))
1062
+ # reduce vspace between video and label
1063
+ display(Label(labels[i], layout=Layout(margin="-8px 0px 0px 0px")))
1064
+ # if show_spec:
1065
+ # display(show_single_spectrogram(tmp_clippaths[i], figsize=(4.5, 1.5)))
1066
+ hbox.append(vbox)
1067
+ hbox = HBox(hbox)
1068
+ display(hbox)
1069
+
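+ # Example usage (illustrative): clip_segments is a list of (start_sec, end_sec)
+ # pairs; the path is hypothetical and clips are written to ~/tmp before display.
+ #
+ #   show_small_clips_in_a_video(
+ #       "data/sample_pouring.mp4",
+ #       clip_segments=[(0.0, 2.0), (5.0, 7.5)],
+ #       labels=["Start", "Middle"],
+ #   )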
1070
+
1071
+ def show_single_video_and_audio(
1072
+ video_path, audio_path, label="Sample video and audio",
1073
+ start=None, end=None, width=360, sr=44100, show=True,
1074
+ ):
1075
+ from moviepy.editor import VideoFileClip
1076
+
1077
+ # Load video
1078
+ video = VideoFileClip(video_path)
1079
+ video_args = {"embed": True, "width": width}
1080
+ filepath = video_path
1081
+
1082
+ # Load audio
1083
+ audio_waveform, sr = librosa.load(audio_path, sr=sr)
1084
+ audio_args = {"data": audio_waveform, "rate": sr}
1085
+
1086
+ if start is not None and end is not None:
1087
+
1088
+ # Cut video from start to end
1089
+ tmp_dir = os.path.join(os.path.expanduser("~"), "tmp")
+ os.makedirs(tmp_dir, exist_ok=True)
1090
+ clip_path = os.path.join(tmp_dir, "clip_sample.mp4")
1091
+ video.subclip(start, end).write_videofile(clip_path, logger=None, verbose=False)
1092
+ filepath = clip_path
1093
+
1094
+ # Cut audio from start to end
1095
+ audio_waveform = audio_waveform[int(start * sr): int(end * sr)]
1096
+ audio_args["data"] = audio_waveform
1097
+
1098
+ out = widgets.Output()
1099
+ with out:
1100
+ label = f"{label} [{start} : {end}]"
1101
+ display(Label(label))
1102
+ display(Video(filepath, **video_args))
1103
+ display(Audio(**audio_args))
1104
+
1105
+ if show:
1106
+ display(out)
1107
+ else:
1108
+ return out
1109
+
1110
+
1111
+ def plot_waveform(waveform, sample_rate, figsize=(10, 2), ax=None, skip=100, show=True, title=None):
1112
+ if isinstance(waveform, torch.Tensor):
1113
+ waveform = waveform.numpy()
1114
+
1115
+ # Subsample along the time (last) axis purely to speed up plotting
+ time_axis = np.arange(waveform.shape[-1]) / sample_rate
1116
+ waveform = waveform[..., ::skip]
1117
+ time_axis = time_axis[::skip]
1118
+
1119
+ if len(waveform.shape) == 1:
1120
+ num_channels = 1
1121
+ num_frames = waveform.shape[0]
1122
+ waveform = waveform.reshape(1, num_frames)
1123
+ elif len(waveform.shape) == 2:
1124
+ num_channels, num_frames = waveform.shape
1125
+ else:
1126
+ raise ValueError(f"Waveform has invalid shape {waveform.shape}")
1127
+
1128
+ if ax is None:
1129
+ figure, axes = plt.subplots(num_channels, 1, figsize=figsize)
1130
+ if num_channels == 1:
1131
+ axes = [axes]
1132
+ for c in range(num_channels):
1133
+ axes[c].plot(time_axis, waveform[c], linewidth=1)
1134
+ axes[c].grid(True)
1135
+ if num_channels > 1:
1136
+ axes[c].set_ylabel(f"Channel {c+1}")
1137
+ figure.suptitle(title)
1138
+ else:
1139
+ assert num_channels == 1
1140
+ ax.plot(time_axis, waveform[0], linewidth=1)
1141
+ ax.grid(True)
1142
+ # ax.set_xticks([])
1143
+ # ax.set_yticks([])
1144
+ # ax.set_xlim(-0.1, 0.1)
1145
+ ax.set_ylim(-0.05, 0.05)
1146
+
1147
+ if show:
1148
+ plt.show(block=False)
1149
+
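+ # Example usage (illustrative; assumes librosa is available and the path exists):
+ #
+ #   y, sr = librosa.load("data/sample_audio.wav", sr=16000)
+ #   plot_waveform(y, sr, skip=50, title="Pouring sound")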
1150
+
1151
+ def show_waveform_as_image(waveform, sr=16000):
1152
+ """Plots a waveform as plt fig and converts into PIL.Image"""
1153
+ fig, ax = plt.subplots(figsize=(10, 2))
1154
+ plot_waveform(waveform, sr, ax=ax, show=False)
1155
+ fig.canvas.draw()
1156
+ img = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
1157
+ plt.close(fig)
1158
+ return img
1159
+
1160
+
1161
+ def plot_raw_audio_signal_with_markings(signal: np.ndarray, markings: list,
1162
+ title: str = 'Raw audio signal with markings',
1163
+ figsize: tuple = (23, 4),
1164
+ ):
1165
+
1166
+ plt.figure(figsize=figsize)
1167
+ plt.grid()
1168
+
1169
+ plt.plot(signal)
1170
+ for value in markings:
1171
+ plt.axvline(x=value, c='red')
1172
+ plt.xlabel('Time')
1173
+ plt.title(title)
1174
+
1175
+ plt.show()
1176
+ plt.close()
1177
+
1178
+
1179
+ def get_concat_h(im1, im2):
1180
+ """Concatenate two images horizontally"""
1181
+ dst = Image.new('RGB', (im1.width + im2.width, im1.height))
1182
+ dst.paste(im1, (0, 0))
1183
+ dst.paste(im2, (im1.width, 0))
1184
+ return dst
1185
+
1186
+
1187
+ def concat_images(images):
1188
+ im1 = images[0]
1189
+ canvas_height = max([im.height for im in images])
1190
+ dst = Image.new('RGB', (sum([im.width for im in images]), canvas_height))
1191
+ start_width = 0
1192
+ for i, im in enumerate(images):
1193
+ if im.height < canvas_height:
1194
+ start_height = (canvas_height - im.height) // 2
1195
+ else:
1196
+ start_height = 0
1197
+ # print(i, start_height)
1198
+ dst.paste(im, (start_width, start_height))
1199
+ start_width += im.width
1200
+ return dst
1201
+
1202
+
1203
+ def concat_images_with_border(images, border_width=5, border_color="white"):
1204
+ im1 = images[0]
1205
+ total_width = sum([im.width for im in images]) + (len(images) - 1) * border_width
1206
+ max_height = max([im.height for im in images])
1207
+ dst = Image.new(
1208
+ 'RGB',
1209
+ (total_width, max_height),
1210
+ border_color,
1211
+ )
1212
+ start_width = 0
1213
+ uniform_height = im1.height
1214
+ canvas_height = max([im.height for im in images])
1215
+ for i, im in enumerate(images):
1216
+ # if im.height != uniform_height:
1217
+ # im = resize_height(im.copy(), uniform_height)
1218
+ if im.height < canvas_height:
1219
+ start_height = (canvas_height - im.height) // 2
1220
+
1221
+ # Pad with zeros at top and bottom
1222
+ im = ImageOps.expand(
1223
+ im, border=(0, start_height, 0, canvas_height - im.height - start_height),
1224
+ )
1225
+ start_height = 0
1226
+ else:
1227
+ start_height = 0
1228
+ dst.paste(im, (start_width, start_height))
1229
+ start_width += im.width + border_width
1230
+ return dst
1231
+
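+ # Example usage (illustrative; the file names are hypothetical):
+ #
+ #   frames = [Image.open(f) for f in ["a.png", "b.png", "c.png"]]
+ #   canvas = concat_images_with_border(frames, border_width=5, border_color="white")
+ #   canvas.save("tiled.png")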
1232
+
1233
+ def concat_images_vertically(images):
1234
+ im1 = images[0]
1235
+ dst = Image.new('RGB', (im1.width, sum([im.height for im in images])))
1236
+ start_height = 0
1237
+ for i, im in enumerate(images):
1238
+ dst.paste(im, (0, start_height))
1239
+ start_height += im.height
1240
+ return dst
1241
+
1242
+
1243
+ def concat_images_vertically_with_border(images, border_width=5, border_color="white"):
1244
+ im1 = images[0]
1245
+ dst = Image.new('RGB', (im1.width, sum([im.height for im in images]) + (len(images) - 1) * border_width), border_color)
1246
+ start_height = 0
1247
+ for i, im in enumerate(images):
1248
+ dst.paste(im, (0, start_height))
1249
+ start_height += im.height + border_width
1250
+ return dst
1251
+
1252
+
1253
+ def get_concat_v(im1, im2):
1254
+ """Concatenate two images vertically"""
1255
+ dst = Image.new('RGB', (im1.width, im1.height + im2.height))
1256
+ dst.paste(im1, (0, 0))
1257
+ dst.paste(im2, (0, im1.height))
1258
+ return dst
1259
+
1260
+
1261
+ def set_latex_fonts(usetex=True, fontsize=14, show_sample=False, **kwargs):
1262
+ try:
1263
+ plt.rcParams.update({
1264
+ "text.usetex": usetex,
1265
+ "font.family": "serif",
1266
+ "font.serif": ["Computer Modern Roman"],
1267
+ "font.size": fontsize,
1268
+ **kwargs,
1269
+ })
1270
+ if show_sample:
1271
+ plt.figure()
1272
+ plt.title("Sample $y = x^2$")
1273
+ plt.plot(np.arange(0, 10), np.arange(0, 10)**2, "--o")
1274
+ plt.grid()
1275
+ plt.show()
1276
+ except Exception:
1277
+ print("Failed to setup LaTeX fonts. Proceeding without.")
1278
+ pass
1279
+
1280
+
1281
+ def get_colors(num_colors, palette="jet"):
1282
+ cmap = plt.get_cmap(palette)
1283
+ colors = [cmap(i) for i in np.linspace(0, 1, num_colors)]
1284
+ return colors
1285
+
1286
+
1287
+ def add_box_on_image(image, bbox, color="red", thickness=3, resized=False, fillcolor=None, fillalpha=0.2):
1288
+ """
1289
+ Adds bounding box on image.
1290
+
1291
+ Args:
1292
+ image (PIL.Image): image
1293
+ bbox (list): [xmin, ymin, xmax, ymax]
1294
+ color: -
1295
+ thickness: -
1296
+ """
1297
+ image = image.copy().convert("RGB")
1298
+ # color = get_predominant_color(color)
1299
+ color = PIL.ImageColor.getrgb(color)
1300
+
1301
+ # Apply alpha to fillcolor
1302
+ if fillcolor is not None:
1303
+ if isinstance(fillcolor, str):
1304
+ fillcolor = PIL.ImageColor.getrgb(fillcolor)
1305
+ fillcolor = fillcolor + (int(fillalpha * 255),)
1306
+ elif isinstance(fillcolor, tuple):
1307
+ if len(fillcolor) == 3:
1308
+ fillcolor = fillcolor + (int(fillalpha * 255),)
1309
+ else:
1310
+ pass
1311
+
1312
+ # Create an instance of the ImageDraw class
1313
+ draw = ImageDraw.Draw(image, "RGBA")
1314
+
1315
+ # Draw the bounding box on the image
1316
+ draw.rectangle(bbox, outline=color, width=thickness, fill=fillcolor)
1317
+
1318
+ # Resize
1319
+ new_width, new_height = (320, 240)
1320
+ if resized:
1321
+ image = image.resize((new_width, new_height))
1322
+
1323
+ return image
1324
+
1325
+
1326
+ def add_multiple_boxes_on_image(image, bboxes, colors=None, thickness=3, resized=False, fillcolor=None, fillalpha=0.2):
1327
+ image = image.copy().convert("RGB")
1328
+ if colors is None:
1329
+ colors = ["red"] * len(bboxes)
1330
+ for bbox, color in zip(bboxes, colors):
1331
+ image = add_box_on_image(image, bbox, color, thickness, resized, fillcolor, fillalpha)
1332
+ return image
1333
+
1334
+
1335
+ def colorize_mask(mask, color="red"):
1336
+ # mask = mask.convert("RGBA")
1337
+ color = PIL.ImageColor.getrgb(color)
1338
+ mask = ImageOps.colorize(mask, (0, 0, 0, 0), color)
1339
+ return mask
1340
+
1341
+
1342
+ def add_mask_on_image(image: Image, mask: Image, color="green", alpha=0.5):
1343
+ image = image.copy()
1344
+ mask = mask.copy()
1345
+
1346
+ # get color if it is a string
1347
+ if isinstance(color, str):
1348
+ color = PIL.ImageColor.getrgb(color)
1349
+ # color = get_predominant_color(color)
1350
+ mask = ImageOps.colorize(mask, (0, 0, 0, 0), color)
1351
+
1352
+ mask = mask.convert("RGB")
1353
+ assert (mask.size == image.size)
1354
+ assert (mask.mode == image.mode)
1355
+
1356
+ # Blend the original image and the segmentation mask with a 50% weight
1357
+ blended_image = Image.blend(image, mask, alpha)
1358
+ return blended_image
1359
+
1360
+
1361
+ def blend_images(img1, img2, alpha=0.5):
1362
+ # Convert images to RGBA
1363
+ img1 = img1.convert("RGBA")
1364
+ img2 = img2.convert("RGBA")
1365
+ alpha_blended = Image.blend(img1, img2, alpha=alpha)
1366
+ # Convert back to RGB
1367
+ alpha_blended = alpha_blended.convert("RGB")
1368
+ return alpha_blended
1369
+
1370
+
1371
+ def visualize_youtube_clip(
1372
+ youtube_id, st, et, label="",
1373
+ show_spec=False,
1374
+ video_width=360, video_height=240,
1375
+ ):
1376
+
1377
+ url = f"https://www.youtube.com/embed/{youtube_id}?start={int(st)}&end={int(et)}"
1378
+ video_html_code = f"""
1379
+ <iframe height="{video_height}" width="{video_width}" src="{url}" frameborder="0" allowfullscreen></iframe>
1380
+ """
1381
+ label_html_code = f"""<b>Caption</b>: {label} <br> <b>Time</b>: {st} to {et}"""
1382
+
1383
+ # Show label and video below it
1384
+ label = widgets.HTML(label_html_code)
1385
+ video = widgets.HTML(video_html_code)
1386
+
1387
+ if show_spec:
1388
+ import pytube
1389
+ import base64
1390
+ from io import BytesIO
1391
+ from moviepy.video.io.VideoFileClip import VideoFileClip
1392
+ from moviepy.audio.io.AudioFileClip import AudioFileClip
1393
+
1394
+ # Load audio directly from youtube
1395
+ video_url = f"https://www.youtube.com/watch?v={youtube_id}"
1396
+ yt = pytube.YouTube(video_url)
1397
+ # Get the audio stream
1398
+ audio_stream = yt.streams.filter(only_audio=True).first()
1399
+
1400
+ # Download audio stream
1401
+ # audio_file = os.path.join("/tmp", "sample_audio.mp3")
1402
+ audio_stream.download(output_path='/tmp', filename='sample.mp4')
1403
+
1404
+ audio_clip = AudioFileClip("/tmp/sample.mp4")
1405
+ audio_subclip = audio_clip.subclip(st, et)
1406
+ sr = audio_subclip.fps
1407
+ y = audio_subclip.to_soundarray().mean(axis=1)
1408
+ audio_subclip.close()
1409
+ audio_clip.close()
1410
+
1411
+ # Compute spectrogram in librosa
1412
+ S_db = librosa.power_to_db(librosa.feature.melspectrogram(y=y, sr=sr), ref=np.max)
1413
+ # Compute figure width in inches from video_width (rcParams dpi is dots per inch)
1414
+ width = video_width / plt.rcParams["figure.dpi"] + 0.63
1415
+ height = video_height / plt.rcParams["figure.dpi"]
1416
+ out = widgets.Output()
1417
+ with out:
1418
+ fig, ax = plt.subplots(figsize=(width, height))
1419
+ librosa.display.specshow(S_db, sr=sr, x_axis='time', ax=ax)
1420
+ ax.set_ylabel("Frequency (Hz)")
1421
+ else:
1422
+ out = widgets.Output()
1423
+
1424
+ vbox = widgets.VBox([label, video, out])
1425
+
1426
+ return vbox
1427
+
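+ # Example usage (illustrative; the video id is a placeholder, and show_spec=True
+ # additionally downloads the audio stream via pytube):
+ #
+ #   vbox = visualize_youtube_clip("abcdefghijk", st=10, et=20, label="sample clip")
+ #   display(vbox)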
1428
+
1429
+ def visualize_pair_of_youtube_clips(clip_a, clip_b):
1430
+ yt_id_a = clip_a["youtube_id"]
1431
+ label_a = clip_a["sentence"]
1432
+ st_a, et_a = clip_a["time"]
1433
+
1434
+ yt_id_b = clip_b["youtube_id"]
1435
+ label_b = clip_b["sentence"]
1436
+ st_b, et_b = clip_b["time"]
1437
+
1438
+ # Show the clips side by side
1439
+ clip_a = visualize_youtube_clip(yt_id_a, st_a, et_a, label_a, show_spec=True)
1440
+ # clip_a = widgets.Output()
1441
+ # with clip_a:
1442
+ # visualize_youtube_clip(yt_id_a, st_a, et_a, label_a, show_spec=True)
1443
+
1444
+ clip_b = visualize_youtube_clip(yt_id_b, st_b, et_b, label_b, show_spec=True)
1445
+ # clip_b = widgets.Output()
1446
+ # with clip_b:
1447
+ # visualize_youtube_clip(yt_id_b, st_b, et_b, label_b, show_spec=True)
1448
+
1449
+ hbox = HBox([
1450
+ clip_a, clip_b
1451
+ ])
1452
+ display(hbox)
1453
+
1454
+
1455
+ def plot_1d(x: np.ndarray, figsize=(6, 2), title=None, xlabel=None, ylabel=None, show=True, **kwargs):
1456
+ assert (x.ndim == 1)
1457
+ fig, ax = plt.subplots(figsize=figsize)
1458
+ ax.grid(alpha=0.3)
1459
+ ax.set_title(title)
1460
+ ax.set_xlabel(xlabel)
1461
+ ax.set_ylabel(ylabel)
1462
+ ax.plot(np.arange(len(x)), x, **kwargs)
1463
+ if show:
1464
+ plt.show()
1465
+ else:
1466
+ plt.close()
1467
+ return fig
1468
+
1469
+
1470
+
1471
+ def make_grid(cols,rows):
1472
+ import streamlit as st
1473
+ grid = [0]*cols
1474
+ for i in range(cols):
1475
+ with st.container():
1476
+ grid[i] = st.columns(rows)
1477
+ return grid
1478
+
1479
+
1480
+ def display_clip(video_path, stime, etime, label=None):
1481
+ """Displays clip at index i."""
1482
+ assert exists(video_path), f"Video does not exist at {video_path}"
1483
+ display(
1484
+ show_small_clips_in_a_video(
1485
+ video_path, [(stime, etime)], labels=[label],
1486
+ ),
1487
+ )
1488
+
1489
+
1490
+ def countplot(df, column, title=None, rotation=90, ylabel="Count", figsize=(8, 5), ax=None, show=True, show_counts=False):
1491
+
1492
+ if ax is None:
1493
+ fig, ax = plt.subplots(figsize=figsize)
1494
+
1495
+ ax.grid(alpha=0.4)
1496
+ ax.set_xlabel(column)
1497
+ ax.set_ylabel(ylabel)
1498
+ ax.set_title(title)
1499
+
1500
+ data = dict(df[column].value_counts())
1501
+ # Extract keys and values from the dictionary
1502
+ categories = list(data.keys())
1503
+ counts = list(data.values())
1504
+
1505
+ # Create a countplot
1506
+ ax.bar(categories, counts)
1507
+ ax.set_xticks(range(len(categories)))
+ ax.set_xticklabels(categories, rotation=rotation)
1508
+
1509
+ # Show count values on top of bars
1510
+ if show_counts:
1511
+ max_v = max(counts)
1512
+ for i, v in enumerate(counts):
1513
+ delta = 0.01 * max_v
1514
+ ax.text(i, v + delta, str(v), ha="center")
1515
+
1516
+ if show:
1517
+ plt.show()
1518
+
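+ # Example usage (illustrative; the dataframe and column are hypothetical):
+ #
+ #   import pandas as pd
+ #   df = pd.DataFrame({"container": ["cup", "mug", "cup", "kettle"]})
+ #   countplot(df, column="container", title="Container types", show_counts=True)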
1519
+
1520
+ def get_linspace_colors(cmap_name='viridis', num_colors = 10):
1521
+ import matplotlib.colors as mcolors
1522
+
1523
+ # Get the colormap object
1524
+ cmap = plt.get_cmap(cmap_name)
1525
+
1526
+ # Get the evenly spaced indices
1527
+ indices = np.arange(0, 1, 1./num_colors)
1528
+
1529
+ # Get the corresponding colors from the colormap
1530
+ colors = [mcolors.to_hex(cmap(idx)) for idx in indices]
1531
+
1532
+ return colors
1533
+
1534
+
1535
+ def hex_to_rgb(colors):
1536
+ from PIL import ImageColor
1537
+ return [ImageColor.getcolor(c, "RGB") for c in colors]
1538
+
1539
+
1540
+ def plot_audio_feature(times, feature, feature_label="Feature", xlabel="Time", figsize=(20, 2)):
1541
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
1542
+ ax.grid(alpha=0.4)
1543
+ ax.set_xlabel(xlabel)
1544
+ ax.set_ylabel(feature_label)
1545
+ ax.set_yticks([])
1546
+
1547
+ ax.plot(times, feature, '--', linewidth=0.5)
1548
+ plt.show()
1549
+
1550
+
1551
+
1552
+ def compute_rms(y, frame_length=512):
1553
+ rms = librosa.feature.rms(y=y, frame_length=frame_length)[0]
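+ # Note: librosa.feature.rms uses hop_length=512 by default, so the time axis
+ # below (frame_length * index) is exact only when frame_length == 512; it also
+ # assumes librosa's default sampling rate (22050 Hz) via samples_to_time.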
1554
+ times = librosa.samples_to_time(frame_length * np.arange(len(rms)))
1555
+ return times, rms
1556
+
1557
+
1558
+ def plot_audio_features(path, label, show=True, show_video=True, features=["rms"], frame_length=512, figsize=(5, 2), return_features=False):
1559
+ # Load audio
1560
+ y, sr = librosa.load(path)
1561
+
1562
+ # Show video
1563
+ if show_video:
1564
+ if show:
1565
+ display(
1566
+ show_single_video_and_spectrogram(
1567
+ path, path, label=label, figsize=figsize,
1568
+ width=410,
1569
+ )
1570
+ )
1571
+ else:
1572
+ if show:
1573
+ # Show audio and spectrogram
1574
+ display(
1575
+ show_single_audio_with_spectrogram(path, label=label, figsize=figsize)
1576
+ )
1577
+
1578
+ feature_data = dict()
1579
+ for f in features:
1580
+ fn = eval(f"compute_{f}")
1581
+ args = dict(y=y, frame_length=frame_length)
1582
+ xvals, yvals = fn(**args)
1583
+ feature_data[f] = (xvals, yvals)
1584
+
1585
+ if show:
1586
+ display(
1587
+ plot_audio_feature(
1588
+ xvals, yvals, feature_label=f.upper(), figsize=(figsize[0] - 0.25, figsize[1]),
1589
+ )
1590
+ )
1591
+
1592
+ if return_features:
1593
+ return feature_data
1594
+
1595
+
1596
+ def rescale_frame(frame, scale=1.):
1597
+ """Rescales a frame by a factor of scale."""
1598
+ return frame.resize((int(frame.width * scale), int(frame.height * scale)))
1599
+
1600
+
1601
+ def save_gif(images, path, duration=None, fps=30):
1602
+ import imageio
1603
+ images = [np.asarray(image) for image in images]
1604
+ if fps is not None:
1605
+ imageio.mimsave(path, images, fps=fps)
1606
+ else:
1607
+ assert duration is not None
1608
+ imageio.mimsave(path, images, duration=duration)
1609
+
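+ # Example usage (illustrative): write a list of PIL frames to an animated GIF.
+ #
+ #   frames = [Image.new("RGB", (64, 64), (i, 0, 0)) for i in range(0, 250, 10)]
+ #   save_gif(frames, "ramp.gif", fps=10)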
1610
+
1611
+ def show_subsampled_frames(frames, n_show, figsize=(15, 3), as_canvas=True):
1612
+ indices = np.arange(len(frames))
1613
+ indices = np.linspace(0, len(frames) - 1, n_show, dtype=int)
1614
+ show_frames = [frames[i] for i in indices]
1615
+ if as_canvas:
1616
+ return concat_images(show_frames)
1617
+ else:
1618
+ show_grid_of_images(show_frames, n_cols=n_show, figsize=figsize, subtitles=indices)
1619
+
1620
+
1621
+ def tensor_to_heatmap(x, scale=True, cmap="viridis", flip_vertically=False):
1622
+ import PIL
1623
+
1624
+ if isinstance(x, torch.Tensor):
1625
+ x = x.numpy()
1626
+
1627
+ if scale:
1628
+ x = (x - x.min()) / (x.max() - x.min())
1629
+
1630
+ cm = plt.get_cmap(cmap)
1631
+ if flip_vertically:
1632
+ x = np.flip(x, axis=0) # put low frequencies at the bottom in image
1633
+ x = cm(x)
1634
+ x = (x * 255).astype(np.uint8)
1635
+ if x.shape[-1] == 3:
1636
+ x = PIL.Image.fromarray(x, mode="RGB")
1637
+ elif x.shape[-1] == 4:
1638
+ x = PIL.Image.fromarray(x, mode="RGBA").convert("RGB")
1639
+ else:
1640
+ raise ValueError(f"Invalid shape {x.shape}")
1641
+ return x
1642
+
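+ # Example usage (illustrative): turn a spectrogram-like 2-D tensor into an RGB
+ # heatmap, with low frequencies at the bottom. The tensor below is random.
+ #
+ #   spec = torch.rand(128, 400)
+ #   heatmap = tensor_to_heatmap(spec, flip_vertically=True)  # returns a PIL.Image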
1643
+
1644
+ def batch_tensor_to_heatmap(x, scale=True, cmap="viridis", flip_vertically=False, resize=None):
1645
+ y = []
1646
+ for i in range(len(x)):
1647
+ h = tensor_to_heatmap(x[i], scale, cmap, flip_vertically)
1648
+ if resize is not None:
1649
+ h = h.resize(resize)
1650
+ y.append(h)
1651
+ return y
1652
+
1653
+
1654
+ def change_contrast(img, level):
1655
+ factor = (259 * (level + 255)) / (255 * (259 - level))
1656
+ def contrast(c):
1657
+ return 128 + factor * (c - 128)
1658
+ return img.point(contrast)
1659
+
1660
+
1661
+ def change_brightness(img, alpha):
1662
+ import PIL
1663
+ enhancer = PIL.ImageEnhance.Brightness(img)
1664
+ # to reduce brightness by 50%, use factor 0.5
1665
+ img = enhancer.enhance(alpha)
1666
+ return img
1667
+
1668
+
1669
+ def draw_horizontal_lines(image, y_values, color=(255, 0, 0), colors=None, line_thickness=2):
1670
+ """
1671
+ Draw horizontal lines on a PIL image at specified Y positions.
1672
+
1673
+ Args:
1674
+ image (PIL.Image.Image): The input PIL image.
1675
+ y_values (list or int): List of Y positions where lines will be drawn.
1676
+ If a single integer is provided, a line will be drawn at that Y position.
1677
+ color (tuple): RGB color tuple (e.g., (255, 0, 0) for red).
1678
+ line_thickness (int): Thickness of the lines.
1679
+
1680
+ Returns:
1681
+ PIL.Image.Image: The PIL image with the drawn lines.
1682
+ """
1683
+ image = image.copy()
1684
+
1685
+ if isinstance(color, str):
1686
+ color = PIL.ImageColor.getcolor(color, "RGB")
1687
+
1688
+ # Normalize y_values to a list before it is used to size the colors list
+ if isinstance(y_values, int):
1689
+ y_values = [y_values]
1690
+
1691
+ if colors is None:
1692
+ colors = [color] * len(y_values)
1693
+ else:
1694
+ if isinstance(colors[0], str):
1695
+ colors = [PIL.ImageColor.getcolor(c, "RGB") for c in colors]
1696
+
1697
+ # Create a drawing context on the image
1698
+ draw = PIL.ImageDraw.Draw(image)
1699
+
1700
+ if isinstance(y_values, int):
1701
+ y_values = [y_values]
1702
+
1703
+ for y, c in zip(y_values, colors):
1704
+ draw.line([(0, y), (image.width, y)], fill=c, width=line_thickness)
1705
+
1706
+ return image
1707
+
1708
+
1709
+ def draw_vertical_lines(image, x_values, color=(255, 0, 0), colors=None, line_thickness=2):
1710
+ """
1711
+ Draw vertical lines on a PIL image at specified X positions.
1712
+
1713
+ Args:
1714
+ image (PIL.Image.Image): The input PIL image.
1715
+ x_values (list or int): List of X positions where lines will be drawn.
1716
+ If a single integer is provided, a line will be drawn at that X position.
1717
+ color (tuple): RGB color tuple (e.g., (255, 0, 0) for red).
1718
+ line_thickness (int): Thickness of the lines.
1719
+
1720
+ Returns:
1721
+ PIL.Image.Image: The PIL image with the drawn lines.
1722
+ """
1723
+ image = image.copy()
1724
+
1725
+ if isinstance(color, str):
1726
+ color = PIL.ImageColor.getcolor(color, "RGB")
1727
+
1728
+ # Normalize x_values to a list before it is used to size the colors list
+ if isinstance(x_values, int):
1729
+ x_values = [x_values]
1730
+
1731
+ if colors is None:
1732
+ colors = [color] * len(x_values)
1733
+ else:
1734
+ if isinstance(colors[0], str):
1735
+ colors = [PIL.ImageColor.getcolor(c, "RGB") for c in colors]
1736
+
1737
+ # Create a drawing context on the image
1738
+ draw = PIL.ImageDraw.Draw(image)
1739
+
1740
+ if isinstance(x_values, int):
1741
+ x_values = [x_values]
1742
+
1743
+ for x, c in zip(x_values, colors):
1744
+ draw.line([(x, 0), (x, image.height)], fill=c, width=line_thickness)
1745
+
1746
+ return image
1747
+
1748
+
1749
+ def show_arrow_on_image(image, start_loc, end_loc, color="red", thickness=3):
1750
+ """Draw a line on PIL image from start_loc to end_loc."""
1751
+ image = image.copy()
1752
+ color = get_predominant_color(color)
1753
+
1754
+ # Create an instance of the ImageDraw class
1755
+ draw = ImageDraw.Draw(image)
1756
+
1757
+ # Draw the bounding box on the image
1758
+ draw.line([start_loc, end_loc], fill=color, width=thickness)
1759
+
1760
+ return image
1761
+
1762
+
1763
+ def draw_arrow_on_image_cv2(image, start_loc, end_loc, color="red", thickness=2, both_ends=False):
1764
+ image = image.copy()
1765
+ image = np.asarray(image)
1766
+ if isinstance(color, str):
1767
+ color = PIL.ImageColor.getcolor(color, "RGB")
1768
+ image = cv2.arrowedLine(image, start_loc, end_loc, color, thickness)
1769
+ if both_ends:
1770
+ image = cv2.arrowedLine(image, end_loc, start_loc, color, thickness)
1771
+ return PIL.Image.fromarray(image)
1772
+
1773
+
1774
+ def draw_arrow_with_text(image, start_loc, end_loc, text="", color="red", thickness=2, font_size=20, both_ends=False, delta=5):
1775
+ image = np.asarray(image)
1776
+ if isinstance(color, str):
1777
+ color = PIL.ImageColor.getcolor(color, "RGB")
1778
+
1779
+ # Calculate the center point between start_loc and end_loc
1780
+ center_x = (start_loc[0] + end_loc[0]) // 2
1781
+ center_y = (start_loc[1] + end_loc[1]) // 2
1782
+ center_point = (center_x, center_y)
1783
+
1784
+ # Draw the arrowed line
1785
+ image = cv2.arrowedLine(image, start_loc, end_loc, color, thickness)
1786
+ if both_ends:
1787
+ image = cv2.arrowedLine(image, end_loc, start_loc, color, thickness)
1788
+
1789
+ # Create a PIL image from the NumPy array for drawing text
1790
+ image_with_text = Image.fromarray(image)
1791
+ draw = PIL.ImageDraw.Draw(image_with_text)
1792
+
1793
+ # Calculate the text size
1794
+ # font = PIL.ImageFont.truetype("arial.ttf", font_size)
1795
+ # This gives an error: "OSError: cannot open resource", as a hack, use the following
1796
+ # draw.textsize was removed in Pillow 10; textbbox gives the same information
+ left, top, right, bottom = draw.textbbox((0, 0), text)
+ text_width, text_height = right - left, bottom - top
1797
+
1798
+ # Calculate the position to center the text
1799
+ text_x = center_x - (text_width // 2) - delta
1800
+ text_y = center_y - (text_height // 2)
1801
+
1802
+ # Draw the text
1803
+ draw.text((text_x, text_y), text, color)
1804
+
1805
+ return image_with_text
1806
+
1807
+
1808
+ def draw_arrowed_line(image, start_loc, end_loc, color="red", thickness=2):
1809
+ """
1810
+ Draw an arrowed line on a PIL image from a starting point to an ending point.
1811
+
1812
+ Args:
1813
+ image (PIL.Image.Image): The input PIL image.
1814
+ start_loc (tuple): Starting point (x, y) for the arrowed line.
1815
+ end_loc (tuple): Ending point (x, y) for the arrowed line.
1816
+ color (str): Color of the line (e.g., 'red', 'green', 'blue').
1817
+ thickness (int): Thickness of the line and arrowhead.
1818
+
1819
+ Returns:
1820
+ PIL.Image.Image: The PIL image with the drawn arrowed line.
1821
+ """
1822
+ image = image.copy()
1823
+ if isinstance(color, str):
1824
+ color = PIL.ImageColor.getcolor(color, "RGB")
1825
+
1826
+
1827
+ # Create a drawing context on the image
1828
+ draw = ImageDraw.Draw(image)
1829
+
1830
+ # Draw a line from start to end
1831
+ draw.line([start_loc, end_loc], fill=color, width=thickness)
1832
+
1833
+ # Calculate arrowhead points
1834
+ arrow_size = 10 # Size of the arrowhead
1835
+ dx = end_loc[0] - start_loc[0]
1836
+ dy = end_loc[1] - start_loc[1]
1837
+ length = (dx ** 2 + dy ** 2) ** 0.5
1838
+ cos_theta = dx / length
1839
+ sin_theta = dy / length
1840
+ x1 = end_loc[0] - arrow_size * cos_theta
1841
+ y1 = end_loc[1] - arrow_size * sin_theta
1842
+ x2 = end_loc[0] - arrow_size * sin_theta
1843
+ y2 = end_loc[1] + arrow_size * cos_theta
1844
+ x3 = end_loc[0] + arrow_size * sin_theta
1845
+ y3 = end_loc[1] - arrow_size * cos_theta
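+ # (x1, y1) lies behind the tip along the line direction; (x2, y2) and (x3, y3)
+ # are offset perpendicular to it at the tip, giving the corners of the arrowhead
+ # polygon drawn below.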
1846
+
1847
+ # Draw the arrowhead triangle
1848
+ draw.polygon([end_loc, (x1, y1), (x2, y2), (x3, y3)], fill=color)
1849
+
1850
+ return image
1851
+
1852
+
1853
+ def center_crop_to_fraction(image, frac=0.5):
1854
+ """Center crop an image to a fraction of its original size."""
1855
+ width, height = image.size
1856
+ new_width = int(width * frac)
1857
+ new_height = int(height * frac)
1858
+ left = (width - new_width) // 2
1859
+ top = (height - new_height) // 2
1860
+ right = (width + new_width) // 2
1861
+ bottom = (height + new_height) // 2
1862
+ return image.crop((left, top, right, bottom))
1863
+
1864
+
1865
+ def decord_load_frames(vr, frame_indices):
1866
+ if isinstance(frame_indices, int):
1867
+ frame_indices = [frame_indices]
1868
+ frames = vr.get_batch(frame_indices).asnumpy()
1869
+ frames = [Image.fromarray(frame) for frame in frames]
1870
+ return frames
1871
+
1872
+
1873
+ def paste_mask_on_image(original_image, bounding_box, mask):
1874
+ """
1875
+ Paste a 2D mask onto the original image at the location specified by the bounding box.
1876
+
1877
+ Parameters:
1878
+ - original_image (PIL.Image): The original image.
1879
+ - bounding_box (tuple): Bounding box coordinates (left, top, right, bottom).
1880
+ - mask (PIL.Image): The 2D mask.
1881
+
1882
+ Returns:
1883
+ - PIL.Image: Image with the mask pasted on it.
1884
+
1885
+ Example:
1886
+ ```
1887
+ original_image = Image.open('original.jpg')
1888
+ bounding_box = (100, 100, 200, 200)
1889
+ mask = Image.open('mask.png')
1890
+ result_image = paste_mask_on_image(original_image, bounding_box, mask)
1891
+ result_image.show()
1892
+ ```
1893
+ """
1894
+ # Create a copy of the original image to avoid modifying the input image
1895
+ result_image = original_image.copy()
1896
+
1897
+ # Crop the mask to the size of the bounding box
1898
+ mask_cropped = mask.crop((0, 0, bounding_box[2] - bounding_box[0], bounding_box[3] - bounding_box[1]))
1899
+
1900
+ # Paste the cropped mask onto the original image at the specified location
1901
+ result_image.paste(mask_cropped, (bounding_box[0], bounding_box[1]))
1902
+
1903
+ return result_image
1904
+
1905
+
1906
+ def display_images_as_video_moviepy(image_list, fps=5, show=True):
1907
+ """
1908
+ Display a list of PIL images as a video in Jupyter Notebook using MoviePy.
1909
+
1910
+ Parameters:
1911
+ - image_list (list): List of PIL images.
1912
+ - fps (int): Frames per second for the video.
1913
+ - show (bool): Whether to display the video in the notebook.
1914
+
1915
+ Example:
1916
+ ```
1917
+ image_list = [Image.open('frame1.jpg'), Image.open('frame2.jpg'), ...]
1918
+ display_images_as_video_moviepy(image_list, fps=10)
1919
+ ```
1920
+ """
1921
+ from IPython.display import display
1922
+ from moviepy.editor import ImageSequenceClip
1923
+
1924
+ image_list = list(map(np.asarray, image_list))
1925
+ clip = ImageSequenceClip(image_list, fps=fps)
1926
+ if show:
1927
+ display(clip.ipython_display(width=200))
1928
+ os.remove("__temp__.mp4")
1929
+
1930
+
1931
+ def resize_height(img, H):
1932
+ w, h = img.size
1933
+ asp_ratio = w / h
1934
+ W = np.ceil(asp_ratio * H).astype(int)
1935
+ return img.resize((W, H))
1936
+
1937
+
1938
+ def resize_width(img, W):
1939
+ w, h = img.size
1940
+ asp_ratio = w / h
1941
+ H = int(W / asp_ratio)
1942
+ return img.resize((W, H))
1943
+
1944
+
1945
+ def resized_minor_side(img, size=256):
1946
+ # PIL's Image.size is (width, height); resize so the minor side equals `size`
+ w, h = img.size
1947
+ if h < w:
1948
+ new_h = size
1949
+ new_w = int(size * w / h)
1950
+ return img.resize((new_w, new_h))
1951
+ else:
1952
+ new_w = size
1953
+ new_h = int(size * h / w)
1954
+ return img.resize((new_w, new_h))
1955
+
1956
+
1957
+ def brighten_image(img, alpha=1.2):
1958
+ enhancer = PIL.ImageEnhance.Brightness(img)
1959
+ img = enhancer.enhance(alpha)
1960
+ return img
1961
+
1962
+
1963
+ def darken_image(img, alpha=0.8):
1964
+ enhancer = PIL.ImageEnhance.Brightness(img)
1965
+ img = enhancer.enhance(alpha)
1966
+ return img
1967
+
1968
+
1969
+ def fig2img(fig):
1970
+ """Convert a Matplotlib figure to a PIL Image and return it"""
1971
+ import io
1972
+ buf = io.BytesIO()
1973
+ fig.savefig(buf)
1974
+ buf.seek(0)
1975
+ img = Image.open(buf)
1976
+ return img
1977
+
1978
+
1979
+ def show_temporal_tsne(
1980
+ tsne,
1981
+ timestamps=None,
1982
+ title="tSNE: feature vectors over time",
1983
+ cmap='viridis',
1984
+ ax=None,
1985
+ fig=None,
1986
+ show=True,
1987
+ num_ticks=10,
1988
+ return_as_pil=False,
1989
+ dpi=100,
1990
+ label='Time (s)',
1991
+ figsize=(6, 4),
1992
+ s=None,
1993
+ ):
1994
+
1995
+ if timestamps is None:
1996
+ timestamps = np.arange(len(tsne))
1997
+
1998
+ if ax is None or fig is None:
1999
+ fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
2000
+
2001
+ cmap = plt.get_cmap(cmap)
2002
+ scatter = ax.scatter(
2003
+ tsne[:, 0], tsne[:, 1], c=np.arange(len(tsne)), cmap=cmap, s=s,
2004
+ edgecolor='k', linewidth=0.5,
2005
+ )
2006
+
2007
+ ax.grid(alpha=0.4)
2008
+ ax.set_title(f"{title}", fontsize=11)
2009
+ ax.set_xlabel("$z_{1}$")
2010
+ ax.set_ylabel("$z_{2}$")
2011
+
2012
+ # Create a colorbar
2013
+ cbar = fig.colorbar(scatter, ax=ax, label=label)
2014
+
2015
+ # Set custom ticks and labels on the colorbar
2016
+ ticks = np.linspace(0, len(tsne) - 1, num_ticks, dtype=int)
2017
+ tick_labels = np.round(timestamps[ticks], 1)
2018
+ cbar.set_ticks(ticks)
2019
+ cbar.set_ticklabels(tick_labels)
2020
+
2021
+ if show:
2022
+ plt.show()
2023
+ else:
2024
+ if return_as_pil:
2025
+ plt.tight_layout(pad=0.2)
2026
+ # fig.canvas.draw()
2027
+ # image = PIL.Image.frombytes(
2028
+ # 'RGB',
2029
+ # fig.canvas.get_width_height(),
2030
+ # fig.canvas.tostring_rgb(),
2031
+ # )
2032
+ # return image
2033
+
2034
+ # Return as PIL Image without displaying the plt figure
2035
+ image = fig2img(fig)
2036
+ plt.close(fig)
2037
+ return image
2038
+
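+ # Example usage (illustrative): `tsne` is expected to be an (N, 2) array of 2-D
+ # embeddings ordered in time; `timestamps` (seconds) only labels the colorbar.
+ # The arrays below are random placeholders.
+ #
+ #   tsne = np.random.randn(100, 2)
+ #   timestamps = np.linspace(0, 20, 100)
+ #   show_temporal_tsne(tsne, timestamps=timestamps)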
2039
+
2040
+ def mark_keypoints(image, keypoints, color=(255, 255, 0), radius=1):
2041
+ """
2042
+ Marks keypoints on an image with a given color and radius.
2043
+
2044
+ :param image: The input PIL image.
2045
+ :param keypoints: A list of (x, y) tuples representing the keypoints.
2046
+ :param color: The color to use for the keypoints (default: yellow).
2047
+ :param radius: The radius of the circle to draw for each keypoint (default: 1).
2048
+ :return: A new PIL image with the keypoints marked.
2049
+ """
2050
+ # Make a copy of the image to avoid modifying the original
2051
+ image_copy = image.copy()
2052
+
2053
+ # Create a draw object to add graphical elements
2054
+ draw = ImageDraw.Draw(image_copy)
2055
+
2056
+ # Loop through each keypoint and draw a circle
2057
+ for x, y in keypoints:
2058
+ # Draw a circle with the specified radius and color
2059
+ draw.ellipse(
2060
+ (x - radius, y - radius, x + radius, y + radius),
2061
+ fill=color,
2062
+ width=2
2063
+ )
2064
+
2065
+ return image_copy
2066
+
2067
+
2068
+ def draw_line_on_image(image, x_coords, y_coords, color=(255, 255, 0), width=3):
2069
+ """
2070
+ Draws a line on an image given lists of x and y coordinates.
2071
+
2072
+ :param image: The input PIL image.
2073
+ :param x_coords: List of x-coordinates for the line.
2074
+ :param y_coords: List of y-coordinates for the line.
2075
+ :param color: Color of the line in RGB (default is yellow).
2076
+ :param width: Width of the line (default is 3).
2077
+ :return: The PIL image with the line drawn.
2078
+ """
2079
+ image = image.copy()
2080
+
2081
+ # Ensure the number of x and y coordinates are the same
2082
+ if len(x_coords) != len(y_coords):
2083
+ raise ValueError("x_coords and y_coords must have the same length")
2084
+
2085
+ # Create a draw object to draw on the image
2086
+ draw = ImageDraw.Draw(image)
2087
+
2088
+ # Create a list of (x, y) coordinate tuples
2089
+ coordinates = list(zip(x_coords, y_coords))
2090
+
2091
+ # Draw the line connecting the coordinates
2092
+ draw.line(coordinates, fill=color, width=width)
2093
+
2094
+ return image
2095
+
2096
+
2097
+ def add_binary_strip_vertically(
2098
+ image,
2099
+ binary_vector,
2100
+ strip_width=15,
2101
+ one_color="yellow",
2102
+ zero_color="gray",
2103
+ ):
2104
+ """
2105
+ Add a binary strip to the right side of an image.
2106
+
2107
+ :param image: PIL Image to which the strip will be added.
2108
+ :param binary_vector: Binary vector with one entry per image row (length must equal the image height).
2109
+ :param strip_width: Width of the strip to be added.
2110
+ :param one_color: Color for "1" pixels (default: yellow).
2111
+ :param zero_color: Color for "0" pixels (default: gray).
2112
+ :return: New image with the binary strip added on the right side.
2113
+ """
2114
+ one_color = PIL.ImageColor.getrgb(one_color)
2115
+ zero_color = PIL.ImageColor.getrgb(zero_color)
2116
+
2117
+ height = image.height
2118
+ if len(binary_vector) != height:
2119
+ raise ValueError("Binary vector must be of length 512")
2120
+
2121
+ # Create a new strip with the specified width and the image height
2122
+ strip = PIL.Image.new("RGB", (strip_width, height))
2123
+
2124
+ # Fill the strip based on the binary vector
2125
+ pixels = strip.load()
2126
+ for i in range(height):
2127
+ color = one_color if binary_vector[i] == 1 else zero_color
2128
+ for w in range(strip_width):
2129
+ pixels[w, i] = color
2130
+
2131
+ # Combine the original image with the new strip
2132
+ # new_image = PIL.Image.new("RGB", (image.width + strip_width, height))
2133
+ # new_image.paste(image, (0, 0))
2134
+ # new_image.paste(strip, (image.width, 0))
2135
+ new_image = image.copy()
2136
+ new_image.paste(strip, (image.width - strip_width, 0))
2137
+
2138
+ return new_image
2139
+
2140
+
2141
+ def add_binary_strip_horizontally(
2142
+ image,
2143
+ binary_vector,
2144
+ strip_height=15,
2145
+ one_color="limegreen",
2146
+ zero_color="gray",
2147
+ ):
2148
+ """
2149
+ Add a binary strip to the top of an image.
2150
+
2151
+ :param image: PIL Image to which the strip will be added.
2152
+ :param binary_vector: Binary vector with one entry per image column (length must equal the image width).
2153
+ :param strip_height: Height of the strip to be added.
2154
+ :param one_color: Color for "1" pixels, accepts color names or hex (default: limegreen).
2155
+ :param zero_color: Color for "0" pixels, accepts color names or hex (default: gray).
2156
+ :return: New image with the binary strip added at the top.
2157
+ """
2158
+ width = image.width
2159
+ if len(binary_vector) != width:
2160
+ raise ValueError("Binary vector must be of length 512")
2161
+
2162
+ # Convert colors to RGB tuples
2163
+ one_color_rgb = PIL.ImageColor.getrgb(one_color)
2164
+ zero_color_rgb = PIL.ImageColor.getrgb(zero_color)
2165
+
2166
+ # Create a new strip with the specified height and the image width
2167
+ strip = PIL.Image.new("RGB", (width, strip_height))
2168
+
2169
+ # Fill the strip based on the binary vector
2170
+ pixels = strip.load()
2171
+ for i in range(width):
2172
+ color = one_color_rgb if binary_vector[i] == 1 else zero_color_rgb
2173
+ for h in range(strip_height):
2174
+ pixels[i, h] = color
2175
+
2176
+ # Combine the original image with the new strip
2177
+ # new_image = PIL.Image.new("RGB", (width, image.height + strip_height))
2178
+ # new_image.paste(strip, (0, 0))
2179
+ # new_image.paste(image, (0, strip_height))
2180
+ new_image = image.copy()
2181
+ new_image.paste(strip, (0, 0))
2182
+
2183
+ return new_image
2184
+
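+ # Example usage (illustrative): overlay a per-column indicator (e.g. a binarised
+ # per-timestep prediction) along the top edge of a frame. The vector length must
+ # equal the image width; the path and signal below are hypothetical.
+ #
+ #   frame = Image.open("frame.png")
+ #   indicator = (np.random.rand(frame.width) > 0.5).astype(int)
+ #   marked = add_binary_strip_horizontally(frame, indicator)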
2185
+
2186
+ # Define a function to increase font sizes for a specific plot
2187
+ def increase_font_sizes(ax, font_scale=1.6):
2188
+ for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
2189
+ ax.get_xticklabels() + ax.get_yticklabels()):
2190
+ item.set_fontsize(item.get_fontsize() * font_scale)
2191
+
2192
+
2193
+
2194
+ def cut_fraction_of_bbox(image, box, frac=0.7):
2195
+ """
2196
+ Cuts the image such that the box occupies a fraction of the image.
2197
+ """
2198
+ W, H = image.size
2199
+ x1, y1, x2, y2 = box
2200
+ w = x2 - x1
2201
+ h = y2 - y1
2202
+ new_w = int(w / frac)
2203
+ new_h = int(h / frac)
2204
+ x1_new = max(0, x1 - (new_w - w) // 2)
2205
+ x2_new = min(W, x2 + (new_w - w) // 2)
2206
+ y1_new = max(0, y1 - (new_h - h) // 2)
2207
+ y2_new = min(H, y2 + (new_h - h) // 2)
2208
+ return image.crop((x1_new, y1_new, x2_new, y2_new))
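+ # Example usage (illustrative; the path and box are hypothetical):
+ #
+ #   img = Image.open("frame.png")
+ #   crop = cut_fraction_of_bbox(img, box=(40, 60, 120, 180), frac=0.7)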