"""Helpers for visualization"""
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import cv2
import PIL
from PIL import Image, ImageOps, ImageDraw
from os.path import exists
import librosa.display
import pandas as pd
import itertools
import librosa
from tqdm import tqdm
from IPython.display import Audio, Markdown, display
from ipywidgets import Button, HBox, VBox, Text, Label, HTML, widgets
from shared.utils.log import tqdm_iterator
import warnings
warnings.filterwarnings("ignore")
try:
    import torchvideotransforms
except ImportError:
    print("Failed to import torchvideotransforms. Proceeding without.")
    print("Please install using:")
    print("pip install git+https://github.com/hassony2/torch_videovision")
# define predominant colors
COLORS = {
"pink": (242, 116, 223),
"cyan": (46, 242, 203),
"red": (255, 0, 0),
"green": (0, 255, 0),
"blue": (0, 0, 255),
"yellow": (255, 255, 0),
}
def get_predominant_color(color_key, mode="RGB", alpha=0):
    assert color_key in COLORS.keys(), f"Unknown color key: {color_key}"
    if mode == "RGB":
        return COLORS[color_key]
    elif mode == "RGBA":
        return COLORS[color_key] + (alpha,)
    else:
        raise ValueError(f"Unknown mode: {mode}; expected 'RGB' or 'RGBA'.")
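# Example (sketch): fetch a predominant color for drawing.
#   get_predominant_color("cyan")                          # -> (46, 242, 203)
#   get_predominant_color("cyan", mode="RGBA", alpha=128)  # -> (46, 242, 203, 128)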
def show_single_image(image: np.ndarray, figsize: tuple = (8, 8), title: str = None, cmap: str = None, ticks=False):
"""Show a single image."""
fig, ax = plt.subplots(1, 1, figsize=figsize)
if isinstance(image, Image.Image):
image = np.asarray(image)
ax.set_title(title)
ax.imshow(image, cmap=cmap)
if not ticks:
ax.set_xticks([])
ax.set_yticks([])
plt.show()
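# Example (sketch): display a random grayscale image; assumes an interactive
# matplotlib backend (e.g. inside a notebook).
#   img = np.random.rand(64, 64)
#   show_single_image(img, figsize=(4, 4), title="Noise", cmap="gray")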
def show_grid_of_images(
images: np.ndarray, n_cols: int = 4, figsize: tuple = (8, 8), subtitlesize=14,
cmap=None, subtitles=None, title=None, save=False, savepath="sample.png", titlesize=20,
ysuptitle=0.8, xlabels=None, sizealpha=0.7, show=True, row_labels=None, aspect=None,
):
"""Show a grid of images."""
n_cols = min(n_cols, len(images))
copy_of_images = images.copy()
for i, image in enumerate(copy_of_images):
if isinstance(image, Image.Image):
image = np.asarray(image)
copy_of_images[i] = image
if subtitles is None:
subtitles = [None] * len(images)
if xlabels is None:
xlabels = [None] * len(images)
if row_labels is None:
num_rows = int(np.ceil(len(images) / n_cols))
row_labels = [None] * num_rows
n_rows = int(np.ceil(len(images) / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
if len(images) == 1:
axes = np.array([[axes]])
    for i, ax in enumerate(axes.flat):
        if i < len(copy_of_images):
            if len(copy_of_images[i].shape) == 2 and cmap is None:
                cmap = "gray"
            ax.imshow(copy_of_images[i], cmap=cmap, aspect=aspect)
            ax.set_title(subtitles[i], fontsize=subtitlesize)
            ax.set_xlabel(xlabels[i], fontsize=sizealpha * subtitlesize)
            ax.set_xticks([])
            ax.set_yticks([])
            col_idx = i % n_cols
            if col_idx == 0:
                ax.set_ylabel(row_labels[i // n_cols], fontsize=sizealpha * subtitlesize)
        else:
            # hide unused axes in the grid
            ax.axis("off")
fig.tight_layout()
plt.suptitle(title, y=ysuptitle, fontsize=titlesize)
if save:
plt.savefig(savepath, bbox_inches='tight')
if show:
plt.show()
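# Example (sketch): show 8 random RGB frames in 2 rows of 4.
#   frames = [np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) for _ in range(8)]
#   show_grid_of_images(frames, n_cols=4, subtitles=[f"frame {i}" for i in range(8)])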
def add_text_to_image(image, text):
    from PIL import ImageFont
    from PIL import ImageDraw
    draw = ImageDraw.Draw(image)
    # Default bitmap font; swap in ImageFont.truetype(...) for a scalable font.
    font = ImageFont.load_default()
    # Use black text on mostly-white images, white text otherwise
    if np.mean(image) > 200:
        draw.text((0, 0), text, (0, 0, 0), font=font)
    else:
        draw.text((0, 0), text, (255, 255, 255), font=font)
    return image
def show_keypoint_matches(
img1, kp1, img2, kp2, matches,
K=10, figsize=(10, 5), drawMatches_args=dict(matchesThickness=3, singlePointColor=(0, 0, 0)),
choose_matches="random",
):
"""Displays matches found in the pair of images"""
    if choose_matches == "random":
        # sample K distinct matches
        selected_matches = np.random.choice(matches, K, replace=False)
elif choose_matches == "all":
K = len(matches)
selected_matches = matches
elif choose_matches == "topk":
selected_matches = matches[:K]
else:
raise ValueError(f"Unknown value for choose_matches: {choose_matches}")
    # matchColor=-1 makes OpenCV assign a distinct random color to each match
    drawMatches_args.update({"matchColor": -1, "singlePointColor": (100, 100, 100)})
img3 = cv2.drawMatches(img1, kp1, img2, kp2, selected_matches, outImg=None, **drawMatches_args)
show_single_image(
img3,
figsize=figsize,
title=f"[{choose_matches.upper()}] Selected K = {K} matches between the pair of images.",
)
return img3
def draw_kps_on_image(image: np.ndarray, kps: np.ndarray, color=COLORS["red"], radius=3, thickness=-1, return_as="PIL"):
"""
Draw keypoints on image.
Args:
image: Image to draw keypoints on.
kps: Keypoints to draw. Note these should be in (x, y) format.
"""
    if isinstance(image, Image.Image):
        image = np.asarray(image)
    if isinstance(color, str):
        color = PIL.ImageColor.getrgb(color)
        colors = [color] * len(kps)
    elif isinstance(color, tuple):
        colors = [color] * len(kps)
    elif isinstance(color, list):
        colors = [PIL.ImageColor.getrgb(c) if isinstance(c, str) else c for c in color]
    else:
        raise ValueError(f"Unsupported color specification: {color}")
    assert len(colors) == len(kps), f"Number of colors ({len(colors)}) must be equal to number of keypoints ({len(kps)})"
for kp, c in zip(kps, colors):
image = cv2.circle(
image.copy(), (int(kp[0]), int(kp[1])), radius=radius, color=c, thickness=thickness)
if return_as == "PIL":
return Image.fromarray(image)
return image
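# Example (sketch): mark three keypoints on a blank canvas.
#   canvas = np.zeros((100, 100, 3), dtype=np.uint8)
#   kps = np.array([[10, 10], [50, 50], [90, 90]])  # (x, y) pairs
#   draw_kps_on_image(canvas, kps, color="red", radius=3)  # -> PIL.Image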
def get_concat_h(im1, im2):
"""Concatenate two images horizontally"""
dst = Image.new('RGB', (im1.width + im2.width, im1.height))
dst.paste(im1, (0, 0))
dst.paste(im2, (im1.width, 0))
return dst
def get_concat_v(im1, im2):
"""Concatenate two images vertically"""
dst = Image.new('RGB', (im1.width, im1.height + im2.height))
dst.paste(im1, (0, 0))
dst.paste(im2, (0, im1.height))
return dst
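# Example (sketch): tile two images side by side, then stack the result on itself.
#   a = Image.new("RGB", (64, 64), "red")
#   b = Image.new("RGB", (64, 64), "blue")
#   row = get_concat_h(a, b)       # 128x64
#   grid = get_concat_v(row, row)  # 128x128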
def show_images_with_keypoints(images: list, kps: list, radius=15, color=(0, 220, 220), figsize=(10, 8)):
assert len(images) == len(kps)
# generate
images_with_kps = []
for i in range(len(images)):
img_with_kps = draw_kps_on_image(images[i], kps[i], radius=radius, color=color, return_as="PIL")
images_with_kps.append(img_with_kps)
# show
show_grid_of_images(images_with_kps, n_cols=len(images), figsize=figsize)
def set_latex_fonts(usetex=True, fontsize=14, show_sample=False, **kwargs):
    try:
        plt.rcParams.update({
            "text.usetex": usetex,
            "font.family": "serif",
            "font.serif": ["Computer Modern Roman"],
            "font.size": fontsize,
            **kwargs,
        })
        if show_sample:
            plt.figure()
            plt.title("Sample $y = x^2$")
            plt.plot(np.arange(0, 10), np.arange(0, 10)**2, "--o")
            plt.grid()
            plt.show()
    except Exception:
        print("Failed to setup LaTeX fonts. Proceeding without.")
def plot_2d_points(
list_of_points_2d,
colors=None,
sizes=None,
markers=None,
alpha=0.75,
h=256,
w=256,
ax=None,
save=True,
savepath="test.png",
):
if ax is None:
fig, ax = plt.subplots(1, 1)
ax.set_xlim([0, w])
ax.set_ylim([0, h])
if sizes is None:
sizes = [0.1 for _ in range(len(list_of_points_2d))]
if colors is None:
colors = ["gray" for _ in range(len(list_of_points_2d))]
if markers is None:
markers = ["o" for _ in range(len(list_of_points_2d))]
for points_2d, color, s, m in zip(list_of_points_2d, colors, sizes, markers):
ax.scatter(points_2d[:, 0], points_2d[:, 1], s=s, alpha=alpha, color=color, marker=m)
if save:
plt.savefig(savepath, bbox_inches='tight')
def plot_2d_points_on_image(
image,
img_alpha=1.0,
ax=None,
list_of_points_2d=[],
scatter_args=dict(),
):
if ax is None:
fig, ax = plt.subplots(1, 1)
ax.imshow(image, alpha=img_alpha)
scatter_args["save"] = False
plot_2d_points(list_of_points_2d, ax=ax, **scatter_args)
# invert the axis
ax.set_ylim(ax.get_ylim()[::-1])
def compare_landmarks(
image, ground_truth_landmarks, v2d, predicted_landmarks,
save=False, savepath="compare_landmarks.png", num_kps_to_show=-1,
show_matches=True,
):
# show GT landmarks on image
fig, axes = plt.subplots(1, 3, figsize=(11, 4))
ax = axes[0]
plot_2d_points_on_image(
image,
list_of_points_2d=[ground_truth_landmarks],
scatter_args=dict(sizes=[15], colors=["limegreen"]),
ax=ax,
)
ax.set_title("GT landmarks", fontsize=12)
# since the projected points are inverted, using 180 degree rotation about z-axis
ax = axes[1]
plot_2d_points_on_image(
image,
list_of_points_2d=[v2d, predicted_landmarks],
scatter_args=dict(sizes=[0.08, 15], markers=["o", "x"], colors=["royalblue", "red"]),
ax=ax,
)
ax.set_title("Projection of predicted mesh", fontsize=12)
# plot the ground truth and predicted landmarks on the same image
ax = axes[2]
plot_2d_points_on_image(
image,
list_of_points_2d=[
ground_truth_landmarks[:num_kps_to_show],
predicted_landmarks[:num_kps_to_show],
],
scatter_args=dict(sizes=[15, 15], markers=["o", "x"], colors=["limegreen", "red"]),
ax=ax,
img_alpha=0.5,
)
ax.set_title("GT and predicted landmarks", fontsize=12)
    if show_matches:
        # num_kps_to_show = -1 means draw a match line for every keypoint
        num_matches = len(ground_truth_landmarks) if num_kps_to_show == -1 else num_kps_to_show
        for i in range(num_matches):
            x_values = [ground_truth_landmarks[i, 0], predicted_landmarks[i, 0]]
            y_values = [ground_truth_landmarks[i, 1], predicted_landmarks[i, 1]]
            ax.plot(x_values, y_values, color="yellow", markersize=1, linewidth=2.)
fig.tight_layout()
if save:
plt.savefig(savepath, bbox_inches="tight")
def plot_histogram_values(
        X, display_vals=False,
        bins=50, figsize=(8, 5),
        show_mean=True,
        xlabel=None, ylabel=None,
        ax=None, title=None, show=False,
        **kwargs,
    ):
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.hist(X, bins=bins, **kwargs)
    if title is None:
        title = "Histogram of values"
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if display_vals:
        # annotate each unique value with its count
        x, counts = np.unique(X, return_counts=True)
        for i in range(len(x)):
            ax.text(x[i], counts[i], counts[i], ha='center', va='bottom')
    ax.grid(alpha=0.3)
    if show_mean:
        mean = np.mean(X)
        mean_string = rf"$\mu$: {mean:.2f}"
        ax.set_title(title + f" ({mean_string}) ")
    else:
        ax.set_title(title)
    if not show:
        return ax
    else:
        plt.show()
"""Helper functions for all kinds of 2D/3D visualization"""
def bokeh_2d_scatter(x, y, desc, figsize=(700, 700), colors=None, use_nb=False, title="Bokeh scatter plot"):
    from bokeh.plotting import figure, show, ColumnDataSource
    from bokeh.models import HoverTool
    from bokeh.io import output_notebook
if use_nb:
output_notebook()
    # if no colors are specified, apply one random color to all points
    if colors is None:
        colors = [np.random.choice(["red", "green", "blue", "yellow", "pink", "black", "gray"])] * len(x)
# define the df of data to plot
source = ColumnDataSource(
data=dict(
x=x,
y=y,
desc=desc,
color=colors,
)
)
# define the attributes to show on hover
hover = HoverTool(
tooltips=[
("index", "$index"),
("(x, y)", "($x, $y)"),
("Desc", "@desc"),
]
)
    p = figure(
        # `plot_width`/`plot_height` are the bokeh < 3.0 names; newer versions use `width`/`height`
        plot_width=figsize[0], plot_height=figsize[1], tools=[hover], title=title,
    )
p.circle('x', 'y', size=10, source=source, fill_color="color")
show(p)
def bokeh_2d_scatter_new(
df, x, y, hue, label, color_column=None, size_col=None,
figsize=(700, 700), use_nb=False, title="Bokeh scatter plot",
legend_loc="bottom_left", edge_color="black", audio_col=None,
):
from bokeh.plotting import figure, output_file, show, ColumnDataSource
from bokeh.models import HoverTool
from bokeh.io import output_notebook
if use_nb:
output_notebook()
assert {x, y, hue, label}.issubset(set(df.keys()))
if isinstance(color_column, str) and color_column in df.keys():
color_column_name = color_column
else:
import matplotlib.colors as mcolors
        colors = list(mcolors.BASE_COLORS.keys()) + list(mcolors.TABLEAU_COLORS.values())
        colors = itertools.cycle(np.unique(colors))
hue_to_color = dict()
unique_hues = np.unique(df[hue].values)
for _hue in unique_hues:
hue_to_color[_hue] = next(colors)
df["color"] = df[hue].apply(lambda k: hue_to_color[k])
color_column_name = "color"
if size_col is not None:
assert isinstance(size_col, str) and size_col in df.keys()
else:
sizes = [10.] * len(df)
df["size"] = sizes
size_col = "size"
source = ColumnDataSource(
dict(
x = df[x].values,
y = df[y].values,
hue = df[hue].values,
label = df[label].values,
color = df[color_column_name].values,
edge_color = [edge_color] * len(df),
sizes = df[size_col].values,
)
)
# define the attributes to show on hover
hover = HoverTool(
tooltips=[
("index", "$index"),
("(x, y)", "($x, $y)"),
("Desc", "@label"),
("Cluster", "@hue"),
]
)
    p = figure(
        # `plot_width`/`plot_height` are the bokeh < 3.0 names; newer versions use `width`/`height`
        plot_width=figsize[0],
        plot_height=figsize[1],
        tools=["pan", "wheel_zoom", "box_zoom", "save", "reset", "help"] + [hover],
        title=title,
    )
p.circle(
'x', 'y', size="sizes",
source=source, fill_color="color",
legend_group="hue", line_color="edge_color",
)
p.legend.location = legend_loc
p.legend.click_policy="hide"
show(p)
import torch
def get_sentence_embedding(model, tokenizer, sentence):
encoded = tokenizer.encode_plus(sentence, return_tensors="pt")
with torch.no_grad():
output = model(**encoded)
last_hidden_state = output.last_hidden_state
    # expects a single sentence and a BERT-base-like encoder (hidden size 768)
    assert last_hidden_state.shape[0] == 1
    assert last_hidden_state.shape[-1] == 768
# only pick the [CLS] token embedding (sentence embedding)
sentence_embedding = last_hidden_state[0, 0]
return sentence_embedding
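# Example (sketch): assumes a BERT-base-style model from HuggingFace
# `transformers` (a hypothetical choice; any 768-dim encoder with a CLS token works).
#   from transformers import AutoModel, AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("bert-base-uncased")
#   model = AutoModel.from_pretrained("bert-base-uncased")
#   emb = get_sentence_embedding(model, tok, "A dog barks.")  # -> shape (768,)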
def lighten_color(color, amount=0.5):
"""
Lightens the given color by multiplying (1-luminosity) by the given amount.
Input can be matplotlib color string, hex string, or RGB tuple.
    Examples:
    >>> lighten_color('g', 0.3)
    >>> lighten_color('#F034A3', 0.6)
    >>> lighten_color((.3, .55, .1), 0.5)
"""
import matplotlib.colors as mc
import colorsys
    try:
        c = mc.cnames[color]
    except KeyError:
        c = color
c = colorsys.rgb_to_hls(*mc.to_rgb(c))
return colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2])
def plot_histogram(df, col, ax=None, color="blue", title=None, xlabel=None, **kwargs):
if ax is None:
fig, ax = plt.subplots(1, 1, figsize=(5, 4))
ax.grid(alpha=0.3)
xlabel = col if xlabel is None else xlabel
ax.set_xlabel(xlabel)
ax.set_ylabel("Frequency")
title = f"Historgam of {col}" if title is None else title
ax.set_title(title)
label = f"Mean: {np.round(df[col].mean(), 1)}"
ax.hist(df[col].values, density=False, color=color, edgecolor=lighten_color(color, 0.1), label=label, **kwargs)
if "bins" in kwargs:
xticks = list(np.arange(kwargs["bins"])[::5])
xticks += list(np.linspace(xticks[-1], int(df[col].max()), 5, dtype=int))
# print(xticks)
ax.set_xticks(xticks)
ax.legend()
plt.show()
def beautify_ax(ax, title=None, titlesize=20, sizealpha=0.7, xlabel=None, ylabel=None):
labelsize = sizealpha * titlesize
ax.grid(alpha=0.3)
ax.set_xlabel(xlabel, fontsize=labelsize)
ax.set_ylabel(ylabel, fontsize=labelsize)
ax.set_title(title, fontsize=titlesize)
def get_text_features(text: list, model, device, batch_size=16):
import clip
text_batches = [text[i:i+batch_size] for i in range(0, len(text), batch_size)]
text_features = []
model = model.to(device)
model = model.eval()
for batch in tqdm(text_batches, desc="Getting text features", bar_format="{l_bar}{bar:20}{r_bar}"):
batch = clip.tokenize(batch).to(device)
with torch.no_grad():
batch_features = model.encode_text(batch)
text_features.append(batch_features.cpu().numpy())
text_features = np.concatenate(text_features, axis=0)
return text_features
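# Example (sketch): assumes OpenAI's `clip` package and a CUDA device are
# available; ViT-B/32 produces 512-dim text features.
#   import clip
#   model, _ = clip.load("ViT-B/32")
#   feats = get_text_features(["a dog", "a cat"], model, device="cuda")
#   feats.shape  # -> (2, 512)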
from sklearn.manifold import TSNE
def reduce_dim(X, perplexity=30, n_iter=1000):
    tsne = TSNE(
        n_components=2,
        perplexity=perplexity,
        n_iter=n_iter,  # note: renamed to `max_iter` in newer scikit-learn
        init='pca',
    )
    Z = tsne.fit_transform(X)
    return Z
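# Example (sketch): project 128-D features to 2-D; t-SNE needs more samples
# than `perplexity`, so use at least ~50 rows here.
#   X = np.random.randn(100, 128)
#   Z = reduce_dim(X)  # -> shape (100, 2)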
from IPython.display import Video
def show_video(video_path):
"""Show a video in a Jupyter notebook"""
assert exists(video_path), f"Video path {video_path} does not exist"
    # e.g. Video(video_path, embed=True, width=600, html_attributes="controls autoplay loop muted")
    return Video(video_path, embed=True, width=480)
def show_single_audio(filepath=None, data=None, rate=None, start=None, end=None, label="Sample audio"):
    if filepath is None:
        assert data is not None and rate is not None, "Either filepath or data and rate must be provided"
    else:
        assert data is None and rate is None, "Either filepath or data and rate must be provided"
        data, rate = librosa.load(filepath)
    args = dict(data=data, rate=rate)
    if start is not None and end is not None:
        # convert start/end from seconds to sample indices
        start = max(int(start * rate), 0)
        end = min(int(end * rate), len(data))
    else:
        start = 0
        end = len(data)
    data = data[start:end]
    args["data"] = data
if label is None:
label = "Sample audio"
label = Label(f"{label}")
out = widgets.Output()
with out:
display(Audio(**args))
vbox = VBox([label, out])
return vbox
def show_single_audio_with_spectrogram(filepath=None, data=None, rate=None, label="Sample audio", figsize=(6, 2)):
if filepath is None:
assert data is not None and rate is not None, "Either filepath or data and rate must be provided"
else:
data, rate = librosa.load(filepath)
# Show audio
vbox = show_single_audio(data=data, rate=rate, label=label)
# get width of audio widget
width = vbox.children[1].layout.width
# Show spectrogram
spec_out = widgets.Output()
D = librosa.stft(data) # STFT of y
S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
with spec_out:
fig, ax = plt.subplots(figsize=figsize)
img = librosa.display.specshow(
S_db,
ax=ax,
x_axis='time',
# y_axis='linear',
)
    # add the spectrogram output to the vbox
vbox.children += (spec_out,)
return vbox
def show_spectrogram(audio_path=None, data=None, rate=None, figsize=(6, 2), ax=None, show=True):
    if data is None and rate is None:
        data, rate = librosa.load(audio_path)
    else:
        assert audio_path is None, "Either audio_path or data and rate must be provided"
    hop_length = 512
    D = librosa.stft(data, n_fft=2048, hop_length=hop_length, win_length=2048)  # STFT of y
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
    # Create the spectrogram plot; remember whether we created the figure here
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        fig = ax.get_figure()
    im = ax.imshow(S_db, origin='lower', aspect='auto', cmap='inferno')
    # Replace xticks with time
    xticks = ax.get_xticks()
    time_in_seconds = librosa.frames_to_time(xticks, sr=rate, hop_length=hop_length)
    ax.set_xticklabels(np.round(time_in_seconds, 1))
    ax.set_xlabel('Time')
    ax.set_yticks([])
    if created_fig:
        # close the figure so it is only rendered inside the widget below
        plt.close(fig)
    # Create widget output
    spec_out = widgets.Output()
    with spec_out:
        display(fig)
    return spec_out
def show_single_video_and_spectrogram(
video_path, audio_path,
label="Sample video", figsize=(6, 2),
width=480,
show_spec_stats=False,
):
# Show video
vbox = show_single_video(video_path, label=label, width=width)
# get width of video widget
width = vbox.children[1].layout.width
# Show spectrogram
data, rate = librosa.load(audio_path)
hop_length = 512
D = librosa.stft(data, n_fft=2048, hop_length=hop_length, win_length=2048) # STFT of y
S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
# Create spectrogram plot widget
fig, ax = plt.subplots(1, 1, figsize=figsize)
im = ax.imshow(S_db, origin='lower', aspect='auto', cmap='inferno')
    # Replace xticks with time
xticks = ax.get_xticks()
time_in_seconds = librosa.frames_to_time(xticks, sr=rate, hop_length=hop_length)
ax.set_xticklabels(np.round(time_in_seconds, 1))
ax.set_xlabel('Time')
ax.set_yticks([])
plt.close(fig)
# Create widget output
spec_out = widgets.Output()
with spec_out:
display(fig)
vbox.children += (spec_out,)
if show_spec_stats:
# Compute mean of spectrogram over frequency axis
eps = 1e-5
S_db_normalized = (S_db - S_db.mean(axis=1)[:, None]) / (S_db.std(axis=1)[:, None] + eps)
S_db_over_time = S_db_normalized.sum(axis=0)
# Plot S_db_over_time
fig, ax = plt.subplots(1, 1, figsize=(6, 2))
# ax.set_title("Spectrogram over time")
ax.grid(alpha=0.5)
x = np.arange(len(S_db_over_time))
x = librosa.frames_to_time(x, sr=rate, hop_length=hop_length)
x = np.round(x, 1)
ax.plot(x, S_db_over_time)
ax.set_xlabel('Time')
ax.set_yticks([])
plt.close(fig)
plot_out = widgets.Output()
with plot_out:
display(fig)
vbox.children += (plot_out,)
return vbox
def show_single_spectrogram(
filepath=None,
data=None,
rate=None,
start=None,
end=None,
ax=None,
label="Sample spectrogram",
figsize=(6, 2),
xlabel="Time",
):
    if filepath is None:
        assert data is not None and rate is not None, "Either filepath or data and rate must be provided"
    else:
        rate = 22050
        offset = start or 0
        clip_duration = (end - offset) if end is not None else None
        data, rate = librosa.load(filepath, sr=rate, offset=offset, duration=clip_duration)
# Show spectrogram
spec_out = widgets.Output()
D = librosa.stft(data) # STFT of y
S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
with spec_out:
img = librosa.display.specshow(
S_db,
ax=ax,
x_axis='time',
sr=rate,
# y_axis='linear',
)
ax.set_xlabel(xlabel)
ax.margins(x=0)
plt.subplots_adjust(wspace=0, hspace=0)
    vbox = VBox([spec_out])
    return vbox
# from decord import VideoReader
def show_single_video(filepath, label="Sample video", width=480, fix_resolution=True):
if label is None:
label = "Sample video"
height = None
if fix_resolution:
aspect_ratio = 16. / 9.
height = int(width * (1/ aspect_ratio))
label = Label(f"{label}")
out = widgets.Output()
with out:
display(Video(filepath, embed=True, width=width, height=height))
vbox = VBox([label, out])
return vbox
def show_grid_of_audio(files, starts=None, ends=None, labels=None, ncols=None, show_spec=False):
for f in files:
assert os.path.exists(f), f"File {f} does not exist."
if labels is None:
labels = [None] * len(files)
if starts is None:
starts = [None] * len(files)
if ends is None:
ends = [None] * len(files)
assert len(files) == len(labels)
if ncols is None:
ncols = 3
nfiles = len(files)
nrows = nfiles // ncols + (nfiles % ncols != 0)
# print(nrows, ncols)
for i in range(nrows):
row_hbox = []
for j in range(ncols):
idx = i * ncols + j
# print(i, j, idx)
if idx < len(files):
file, label = files[idx], labels[idx]
start, end = starts[idx], ends[idx]
vbox = show_single_audio(
filepath=file, label=label, start=start, end=end
)
if show_spec:
spec_box = show_spectrogram(file, figsize=(3.6, 1))
# Add spectrogram to vbox
vbox.children += (spec_box,)
# if not show_spec:
# vbox = show_single_audio(
# filepath=file, label=label, start=start, end=end
# )
# else:
# vbox = show_single_audio_with_spectrogram(
# filepath=file, label=label
# )
row_hbox.append(vbox)
row_hbox = HBox(row_hbox)
display(row_hbox)
def show_grid_of_videos(
files,
cut=False,
starts=None,
ends=None,
labels=None,
ncols=None,
width_overflow=False,
show_spec=False,
width_of_screen=1000,
):
from moviepy.editor import VideoFileClip
for f in files:
assert os.path.exists(f), f"File {f} does not exist."
if labels is None:
labels = [None] * len(files)
if starts is not None and ends is not None:
cut = True
if starts is None:
starts = [None] * len(files)
if ends is None:
ends = [None] * len(files)
assert len(files) == len(labels) == len(starts) == len(ends)
    # cut the videos to the specified durations
    if cut:
        cut_files = []
        tmp_dir = os.path.join(os.path.expanduser("~"), "tmp")
        os.makedirs(tmp_dir, exist_ok=True)
        for i, f in enumerate(files):
            start, end = starts[i], ends[i]
            tmp_f = os.path.join(tmp_dir, f"clip_{i}.mp4")
            cut_files.append(tmp_f)
            video = VideoFileClip(f)
            start = 0 if start is None else start
            end = video.duration - 1 if end is None else end
            video.subclip(start, end).write_videofile(tmp_f, logger=None, verbose=False)
        files = cut_files
    if ncols is None:
        ncols = 3
    # estimate the width of a single video from the overall screen width
    if not width_overflow:
        width_of_single_video = width_of_screen // ncols
    else:
        width_of_single_video = 280
nfiles = len(files)
nrows = nfiles // ncols + (nfiles % ncols != 0)
# print(nrows, ncols)
for i in range(nrows):
row_hbox = []
for j in range(ncols):
idx = i * ncols + j
# print(i, j, idx)
if idx < len(files):
file, label = files[idx], labels[idx]
if not show_spec:
vbox = show_single_video(file, label, width_of_single_video)
else:
vbox = show_single_video_and_spectrogram(file, file, width=width_of_single_video, label=label)
row_hbox.append(vbox)
row_hbox = HBox(row_hbox)
display(row_hbox)
def preview_video(fp, label="Sample video frames", mode="uniform", frames_to_show=6):
from decord import VideoReader
assert exists(fp), f"Video does not exist at {fp}"
vr = VideoReader(fp)
nfs = len(vr)
fps = vr.get_avg_fps()
dur = nfs / fps
if mode == "all":
frame_indices = np.arange(nfs)
elif mode == "uniform":
frame_indices = np.linspace(0, nfs - 1, frames_to_show, dtype=int)
elif mode == "random":
frame_indices = np.random.randint(0, nfs - 1, replace=False)
frame_indices = sorted(frame_indices)
else:
raise ValueError(f"Unknown frame viewing mode {mode}.")
# Show grid of image
images = vr.get_batch(frame_indices).asnumpy()
show_grid_of_images(images, n_cols=len(frame_indices), title=label, figsize=(12, 2.3), titlesize=10)
def preview_multiple_videos(fps, labels, mode="uniform", frames_to_show=6):
for fp in fps:
assert exists(fp), f"Video does not exist at {fp}"
for fp, label in zip(fps, labels):
preview_video(fp, label, mode=mode, frames_to_show=frames_to_show)
def show_small_clips_in_a_video(
video_path,
clip_segments: list,
width=360,
labels=None,
show_spec=False,
resize=False,
):
from moviepy.editor import VideoFileClip
from ipywidgets import Layout
video = VideoFileClip(video_path)
if resize:
# Resize the video
print("Resizing the video to width", width)
video = video.resize(width=width)
if labels is None:
labels = [
f"Clip {i+1} [{clip_segments[i][0]} : {clip_segments[i][1]}]" for i in range(len(clip_segments))
]
else:
assert len(labels) == len(clip_segments)
    tmp_dir = os.path.join(os.path.expanduser("~"), "tmp")
    os.makedirs(tmp_dir, exist_ok=True)
    tmp_clippaths = [f"{tmp_dir}/clip_{i}.mp4" for i in range(len(clip_segments))]
    iterator = tqdm_iterator(zip(clip_segments, tmp_clippaths), total=len(clip_segments), desc="Preparing clips")
    # write each subclip to disk (write_videofile returns None; only the files are needed)
    for (x, y), f in iterator:
        video.subclip(x, y).write_videofile(f, logger=None, verbose=False)
    hbox = []
    for i in range(len(tmp_clippaths)):
        vbox = widgets.Output()
with vbox:
if show_spec:
display(
show_single_video_and_spectrogram(
tmp_clippaths[i], tmp_clippaths[i],
width=width, figsize=(4.4, 1.5),
)
)
else:
display(Video(tmp_clippaths[i], embed=True, width=width))
            # reduce vertical space between video and label
            display(Label(labels[i], layout=Layout(margin="-8px 0px 0px 0px")))
hbox.append(vbox)
hbox = HBox(hbox)
display(hbox)
def show_single_video_and_audio(
video_path, audio_path, label="Sample video and audio",
start=None, end=None, width=360, sr=44100, show=True,
):
from moviepy.editor import VideoFileClip
# Load video
video = VideoFileClip(video_path)
video_args = {"embed": True, "width": width}
filepath = video_path
# Load audio
audio_waveform, sr = librosa.load(audio_path, sr=sr)
audio_args = {"data": audio_waveform, "rate": sr}
    if start is not None and end is not None:
        # Cut video from start to end
        tmp_dir = os.path.join(os.path.expanduser("~"), "tmp")
        os.makedirs(tmp_dir, exist_ok=True)
        clip_path = os.path.join(tmp_dir, "clip_sample.mp4")
        video.subclip(start, end).write_videofile(clip_path, logger=None, verbose=False)
        filepath = clip_path
        # Cut audio from start to end
        audio_waveform = audio_waveform[int(start * sr): int(end * sr)]
        audio_args["data"] = audio_waveform
out = widgets.Output()
with out:
label = f"{label} [{start} : {end}]"
display(Label(label))
display(Video(filepath, **video_args))
display(Audio(**audio_args))
if show:
display(out)
else:
return out
def plot_waveform(waveform, sample_rate, figsize=(10, 2), ax=None, skip=100, show=True, title=None):
    if isinstance(waveform, torch.Tensor):
        waveform = waveform.numpy()
    # normalize shape to (num_channels, num_frames) before subsampling,
    # so stereo waveforms are subsampled along time rather than channels
    if len(waveform.shape) == 1:
        waveform = waveform.reshape(1, -1)
    elif len(waveform.shape) != 2:
        raise ValueError(f"Waveform has invalid shape {waveform.shape}")
    num_channels, num_frames = waveform.shape
    time_axis = np.arange(0, num_frames) / sample_rate
    # subsample along the time axis for faster plotting
    waveform = waveform[:, ::skip]
    time_axis = time_axis[::skip]
if ax is None:
figure, axes = plt.subplots(num_channels, 1, figsize=figsize)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].plot(time_axis, waveform[c], linewidth=1)
axes[c].grid(True)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle(title)
else:
assert num_channels == 1
ax.plot(time_axis, waveform[0], linewidth=1)
ax.grid(True)
        # note: fixed y-limits tuned for quiet waveforms
        ax.set_ylim(-0.05, 0.05)
if show:
plt.show(block=False)
def show_waveform_as_image(waveform, sr=16000):
"""Plots a waveform as plt fig and converts into PIL.Image"""
fig, ax = plt.subplots(figsize=(10, 2))
plot_waveform(waveform, sr, ax=ax, show=False)
fig.canvas.draw()
    # note: `tostring_rgb` was removed in Matplotlib 3.8; use `buffer_rgba()` there
    img = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
plt.close(fig)
return img
def plot_raw_audio_signal_with_markings(signal: np.ndarray, markings: list,
title: str = 'Raw audio signal with markings',
figsize: tuple = (23, 4),
):
plt.figure(figsize=figsize)
plt.grid()
plt.plot(signal)
for value in markings:
plt.axvline(x=value, c='red')
plt.xlabel('Time')
plt.title(title)
plt.show()
plt.close()
def concat_images(images):
    """Concatenate images horizontally, vertically centering shorter ones."""
    canvas_height = max([im.height for im in images])
    total_width = sum([im.width for im in images])
    dst = Image.new('RGB', (total_width, canvas_height))
    start_width = 0
    for im in images:
        start_height = (canvas_height - im.height) // 2
        dst.paste(im, (start_width, start_height))
        start_width += im.width
    return dst
def concat_images_with_border(images, border_width=5, border_color="white"):
    total_width = sum([im.width for im in images]) + (len(images) - 1) * border_width
    canvas_height = max([im.height for im in images])
    dst = Image.new('RGB', (total_width, canvas_height), border_color)
    start_width = 0
    for im in images:
        if im.height < canvas_height:
            # pad shorter images equally at top and bottom
            pad_top = (canvas_height - im.height) // 2
            im = ImageOps.expand(
                im, border=(0, pad_top, 0, canvas_height - im.height - pad_top),
            )
        dst.paste(im, (start_width, 0))
        start_width += im.width + border_width
    return dst
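# Example (sketch): join three solid tiles of different heights with a
# 5-pixel black border; shorter tiles get padded to the tallest height.
#   tiles = [Image.new("RGB", (40, h), "gray") for h in (60, 80, 70)]
#   strip = concat_images_with_border(tiles, border_width=5, border_color="black")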
def concat_images_vertically(images):
im1 = images[0]
dst = Image.new('RGB', (im1.width, sum([im.height for im in images])))
start_height = 0
for i, im in enumerate(images):
dst.paste(im, (0, start_height))
start_height += im.height
return dst
def concat_images_vertically_with_border(images, border_width=5, border_color="white"):
im1 = images[0]
dst = Image.new('RGB', (im1.width, sum([im.height for im in images]) + (len(images) - 1) * border_width), border_color)
start_height = 0
for i, im in enumerate(images):
dst.paste(im, (0, start_height))
start_height += im.height + border_width
return dst
def get_colors(num_colors, palette="jet"):
cmap = plt.get_cmap(palette)
colors = [cmap(i) for i in np.linspace(0, 1, num_colors)]
return colors
def add_box_on_image(image, bbox, color="red", thickness=3, resized=False, fillcolor=None, fillalpha=0.2):
    """
    Adds a bounding box on the image.
    Args:
        image (PIL.Image): image
        bbox (list): [xmin, ymin, xmax, ymax]
        color: outline color (name or hex string)
        thickness: outline width in pixels
    """
    image = image.copy().convert("RGB")
    color = PIL.ImageColor.getrgb(color)
    # Apply alpha to the fill color
    if fillcolor is not None:
        if isinstance(fillcolor, str):
            fillcolor = PIL.ImageColor.getrgb(fillcolor)
            fillcolor = fillcolor + (int(fillalpha * 255),)
        elif isinstance(fillcolor, tuple) and len(fillcolor) == 3:
            fillcolor = fillcolor + (int(fillalpha * 255),)
# Create an instance of the ImageDraw class
draw = ImageDraw.Draw(image, "RGBA")
# Draw the bounding box on the image
draw.rectangle(bbox, outline=color, width=thickness, fill=fillcolor)
# Resize
new_width, new_height = (320, 240)
if resized:
image = image.resize((new_width, new_height))
return image
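# Example (sketch): draw a translucent red box on a blank image.
#   img = Image.new("RGB", (320, 240), "white")
#   boxed = add_box_on_image(img, [40, 40, 160, 120], color="red", fillcolor="red", fillalpha=0.2)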
def add_multiple_boxes_on_image(image, bboxes, colors=None, thickness=3, resized=False, fillcolor=None, fillalpha=0.2):
image = image.copy().convert("RGB")
if colors is None:
colors = ["red"] * len(bboxes)
for bbox, color in zip(bboxes, colors):
image = add_box_on_image(image, bbox, color, thickness, resized, fillcolor, fillalpha)
return image
def colorize_mask(mask, color="red"):
# mask = mask.convert("RGBA")
color = PIL.ImageColor.getrgb(color)
mask = ImageOps.colorize(mask, (0, 0, 0, 0), color)
return mask
def add_mask_on_image(image: Image, mask: Image, color="green", alpha=0.5):
image = image.copy()
mask = mask.copy()
# get color if it is a string
if isinstance(color, str):
color = PIL.ImageColor.getrgb(color)
# color = get_predominant_color(color)
mask = ImageOps.colorize(mask, (0, 0, 0, 0), color)
mask = mask.convert("RGB")
assert (mask.size == image.size)
assert (mask.mode == image.mode)
    # Blend the original image and the segmentation mask with weight `alpha`
blended_image = Image.blend(image, mask, alpha)
return blended_image
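# Example (sketch): overlay a binary mask in green at 50% opacity; the mask
# must be a single-channel ("L") image of the same size as the input.
#   img = Image.new("RGB", (100, 100), "black")
#   mask = Image.new("L", (100, 100), 0)
#   mask.paste(255, (25, 25, 75, 75))  # white square = masked region
#   overlay = add_mask_on_image(img, mask, color="green", alpha=0.5)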
def blend_images(img1, img2, alpha=0.5):
# Convert images to RGBA
img1 = img1.convert("RGBA")
img2 = img2.convert("RGBA")
alpha_blended = Image.blend(img1, img2, alpha=alpha)
# Convert back to RGB
alpha_blended = alpha_blended.convert("RGB")
return alpha_blended
def visualize_youtube_clip(
youtube_id, st, et, label="",
show_spec=False,
video_width=360, video_height=240,
):
url = f"https://www.youtube.com/embed/{youtube_id}?start={int(st)}&end={int(et)}"
video_html_code = f"""
"""
label_html_code = f"""Caption: {label}
Time: {st} to {et}"""
# Show label and video below it
label = widgets.HTML(label_html_code)
video = widgets.HTML(video_html_code)
if show_spec:
import pytube
import base64
from io import BytesIO
from moviepy.video.io.VideoFileClip import VideoFileClip
from moviepy.audio.io.AudioFileClip import AudioFileClip
# Load audio directly from youtube
video_url = f"https://www.youtube.com/watch?v={youtube_id}"
yt = pytube.YouTube(video_url)
# Get the audio stream
audio_stream = yt.streams.filter(only_audio=True).first()
# Download audio stream
# audio_file = os.path.join("/tmp", "sample_audio.mp3")
audio_stream.download(output_path='/tmp', filename='sample.mp4')
audio_clip = AudioFileClip("/tmp/sample.mp4")
audio_subclip = audio_clip.subclip(st, et)
sr = audio_subclip.fps
y = audio_subclip.to_soundarray().mean(axis=1)
audio_subclip.close()
audio_clip.close()
# Compute spectrogram in librosa
        S_db = librosa.power_to_db(librosa.feature.melspectrogram(y=y, sr=sr), ref=np.max)
# Compute width in cms from video_width
width = video_width / plt.rcParams["figure.dpi"] + 0.63
height = video_height / plt.rcParams["figure.dpi"]
out = widgets.Output()
with out:
fig, ax = plt.subplots(figsize=(width, height))
librosa.display.specshow(S_db, sr=sr, x_axis='time', ax=ax)
ax.set_ylabel("Frequency (Hz)")
else:
out = widgets.Output()
vbox = widgets.VBox([label, video, out])
return vbox
def visualize_pair_of_youtube_clips(clip_a, clip_b):
yt_id_a = clip_a["youtube_id"]
label_a = clip_a["sentence"]
st_a, et_a = clip_a["time"]
yt_id_b = clip_b["youtube_id"]
label_b = clip_b["sentence"]
st_b, et_b = clip_b["time"]
    # Show the clips side by side
    clip_a = visualize_youtube_clip(yt_id_a, st_a, et_a, label_a, show_spec=True)
    clip_b = visualize_youtube_clip(yt_id_b, st_b, et_b, label_b, show_spec=True)
    hbox = HBox([clip_a, clip_b])
    display(hbox)
def plot_1d(x: np.ndarray, figsize=(6, 2), title=None, xlabel=None, ylabel=None, show=True, **kwargs):
assert (x.ndim == 1)
fig, ax = plt.subplots(figsize=figsize)
ax.grid(alpha=0.3)
ax.set_title(title)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.plot(np.arange(len(x)), x, **kwargs)
if show:
plt.show()
else:
plt.close()
return fig
def make_grid(cols, rows):
    import streamlit as st
    grid = [0] * cols
    for i in range(cols):
        with st.container():
            grid[i] = st.columns(rows)
    return grid
def display_clip(video_path, stime, etime, label=None):
"""Displays clip at index i."""
assert exists(video_path), f"Video does not exist at {video_path}"
display(
show_small_clips_in_a_video(
video_path, [(stime, etime)], labels=[label],
),
)
def countplot(df, column, title=None, rotation=90, ylabel="Count", figsize=(8, 5), ax=None, show=True, show_counts=False):
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
ax.grid(alpha=0.4)
ax.set_xlabel(column)
ax.set_ylabel(ylabel)
ax.set_title(title)
data = dict(df[column].value_counts())
# Extract keys and values from the dictionary
categories = list(data.keys())
counts = list(data.values())
# Create a countplot
ax.bar(categories, counts)
    ax.set_xticks(range(len(categories)))
    ax.set_xticklabels(categories, rotation=rotation)
# Show count values on top of bars
if show_counts:
max_v = max(counts)
for i, v in enumerate(counts):
delta = 0.01 * max_v
ax.text(i, v + delta, str(v), ha="center")
if show:
plt.show()
def get_linspace_colors(cmap_name='viridis', num_colors = 10):
import matplotlib.colors as mcolors
# Get the colormap object
cmap = plt.cm.get_cmap(cmap_name)
# Get the evenly spaced indices
indices = np.arange(0, 1, 1./num_colors)
# Get the corresponding colors from the colormap
colors = [mcolors.to_hex(cmap(idx)) for idx in indices]
return colors
def hex_to_rgb(colors):
from PIL import ImageColor
return [ImageColor.getcolor(c, "RGB") for c in colors]
def plot_audio_feature(times, feature, feature_label="Feature", xlabel="Time", figsize=(20, 2)):
fig, ax = plt.subplots(1, 1, figsize=figsize)
ax.grid(alpha=0.4)
ax.set_xlabel(xlabel)
ax.set_ylabel(feature_label)
ax.set_yticks([])
ax.plot(times, feature, '--', linewidth=0.5)
plt.show()
def compute_rms(y, frame_length=512):
    rms = librosa.feature.rms(y=y, frame_length=frame_length)[0]
    # note: this assumes librosa's default hop_length (512) matches `frame_length`
    times = librosa.samples_to_time(frame_length * np.arange(len(rms)))
    return times, rms
def plot_audio_features(path, label, show=True, show_video=True, features=["rms"], frame_length=512, figsize=(5, 2), return_features=False):
# Load audio
y, sr = librosa.load(path)
# Show video
if show_video:
if show:
display(
show_single_video_and_spectrogram(
path, path, label=label, figsize=figsize,
width=410,
)
)
else:
if show:
# Show audio and spectrogram
display(
show_single_audio_with_spectrogram(path, label=label, figsize=figsize)
)
feature_data = dict()
for f in features:
        # look up the compute_<feature> helper defined in this module
        fn = globals()[f"compute_{f}"]
args = dict(y=y, frame_length=frame_length)
xvals, yvals = fn(**args)
feature_data[f] = (xvals, yvals)
if show:
display(
plot_audio_feature(
xvals, yvals, feature_label=f.upper(), figsize=(figsize[0] - 0.25, figsize[1]),
)
)
if return_features:
return feature_data
def rescale_frame(frame, scale=1.):
"""Rescales a frame by a factor of scale."""
return frame.resize((int(frame.width * scale), int(frame.height * scale)))
def save_gif(images, path, duration=None, fps=30):
import imageio
images = [np.asarray(image) for image in images]
if fps is not None:
imageio.mimsave(path, images, fps=fps)
else:
assert duration is not None
imageio.mimsave(path, images, duration=duration)
def show_subsampled_frames(frames, n_show, figsize=(15, 3), as_canvas=True):
    indices = np.linspace(0, len(frames) - 1, n_show, dtype=int)
show_frames = [frames[i] for i in indices]
if as_canvas:
return concat_images(show_frames)
else:
show_grid_of_images(show_frames, n_cols=n_show, figsize=figsize, subtitles=indices)
def tensor_to_heatmap(x, scale=True, cmap="viridis", flip_vertically=False):
import PIL
if isinstance(x, torch.Tensor):
x = x.numpy()
if scale:
x = (x - x.min()) / (x.max() - x.min())
cm = plt.get_cmap(cmap)
if flip_vertically:
x = np.flip(x, axis=0) # put low frequencies at the bottom in image
x = cm(x)
x = (x * 255).astype(np.uint8)
if x.shape[-1] == 3:
x = PIL.Image.fromarray(x, mode="RGB")
elif x.shape[-1] == 4:
x = PIL.Image.fromarray(x, mode="RGBA").convert("RGB")
else:
raise ValueError(f"Invalid shape {x.shape}")
return x
def batch_tensor_to_heatmap(x, scale=True, cmap="viridis", flip_vertically=False, resize=None):
y = []
for i in range(len(x)):
h = tensor_to_heatmap(x[i], scale, cmap, flip_vertically)
if resize is not None:
h = h.resize(resize)
y.append(h)
return y
def change_contrast(img, level):
factor = (259 * (level + 255)) / (255 * (259 - level))
def contrast(c):
return 128 + factor * (c - 128)
return img.point(contrast)
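# Example (sketch): `level` ranges over roughly [-255, 255]; positive values
# increase contrast, negative values flatten it.
#   img = Image.open("sample.jpg")  # hypothetical input path
#   high_contrast = change_contrast(img, 100)
#   low_contrast = change_contrast(img, -100)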
def change_brightness(img, alpha):
import PIL
enhancer = PIL.ImageEnhance.Brightness(img)
# to reduce brightness by 50%, use factor 0.5
img = enhancer.enhance(alpha)
return img
def draw_horizontal_lines(image, y_values, color=(255, 0, 0), colors=None, line_thickness=2):
"""
Draw horizontal lines on a PIL image at specified Y positions.
Args:
image (PIL.Image.Image): The input PIL image.
y_values (list or int): List of Y positions where lines will be drawn.
If a single integer is provided, a line will be drawn at that Y position.
color (tuple): RGB color tuple (e.g., (255, 0, 0) for red).
line_thickness (int): Thickness of the lines.
Returns:
PIL.Image.Image: The PIL image with the drawn lines.
"""
    image = image.copy()
    # normalize a single int to a list before computing lengths
    if isinstance(y_values, int):
        y_values = [y_values]
    if isinstance(color, str):
        color = PIL.ImageColor.getcolor(color, "RGB")
    if colors is None:
        colors = [color] * len(y_values)
    else:
        if isinstance(colors[0], str):
            colors = [PIL.ImageColor.getcolor(c, "RGB") for c in colors]
    # Create a drawing context on the image
    draw = PIL.ImageDraw.Draw(image)
for y, c in zip(y_values, colors):
draw.line([(0, y), (image.width, y)], fill=c, width=line_thickness)
return image
def draw_vertical_lines(image, x_values, color=(255, 0, 0), colors=None, line_thickness=2):
"""
Draw vertical lines on a PIL image at specified X positions.
Args:
image (PIL.Image.Image): The input PIL image.
x_values (list or int): List of X positions where lines will be drawn.
If a single integer is provided, a line will be drawn at that X position.
color (tuple): RGB color tuple (e.g., (255, 0, 0) for red).
line_thickness (int): Thickness of the lines.
Returns:
PIL.Image.Image: The PIL image with the drawn lines.
"""
    image = image.copy()
    # normalize a single int to a list before computing lengths
    if isinstance(x_values, int):
        x_values = [x_values]
    if isinstance(color, str):
        color = PIL.ImageColor.getcolor(color, "RGB")
    if colors is None:
        colors = [color] * len(x_values)
    else:
        if isinstance(colors[0], str):
            colors = [PIL.ImageColor.getcolor(c, "RGB") for c in colors]
    # Create a drawing context on the image
    draw = PIL.ImageDraw.Draw(image)
for x, c in zip(x_values, colors):
draw.line([(x, 0), (x, image.height)], fill=c, width=line_thickness)
return image
def show_arrow_on_image(image, start_loc, end_loc, color="red", thickness=3):
"""Draw a line on PIL image from start_loc to end_loc."""
image = image.copy()
color = get_predominant_color(color)
# Create an instance of the ImageDraw class
draw = ImageDraw.Draw(image)
# Draw the bounding box on the image
draw.line([start_loc, end_loc], fill=color, width=thickness)
return image
def draw_arrow_on_image_cv2(image, start_loc, end_loc, color="red", thickness=2, both_ends=False):
image = image.copy()
image = np.asarray(image)
if isinstance(color, str):
color = PIL.ImageColor.getcolor(color, "RGB")
image = cv2.arrowedLine(image, start_loc, end_loc, color, thickness)
if both_ends:
image = cv2.arrowedLine(image, end_loc, start_loc, color, thickness)
return PIL.Image.fromarray(image)
def draw_arrow_with_text(image, start_loc, end_loc, text="", color="red", thickness=2, font_size=20, both_ends=False, delta=5):
image = np.asarray(image)
if isinstance(color, str):
color = PIL.ImageColor.getcolor(color, "RGB")
# Calculate the center point between start_loc and end_loc
center_x = (start_loc[0] + end_loc[0]) // 2
center_y = (start_loc[1] + end_loc[1]) // 2
center_point = (center_x, center_y)
# Draw the arrowed line
image = cv2.arrowedLine(image, start_loc, end_loc, color, thickness)
if both_ends:
image = cv2.arrowedLine(image, end_loc, start_loc, color, thickness)
# Create a PIL image from the NumPy array for drawing text
image_with_text = Image.fromarray(image)
draw = PIL.ImageDraw.Draw(image_with_text)
    # Calculate the text size (draw.textsize was removed in Pillow 10; textbbox replaces it)
    left, top, right, bottom = draw.textbbox((0, 0), text)
    text_width, text_height = right - left, bottom - top
# Calculate the position to center the text
text_x = center_x - (text_width // 2) - delta
text_y = center_y - (text_height // 2)
# Draw the text
draw.text((text_x, text_y), text, color)
return image_with_text
def draw_arrowed_line(image, start_loc, end_loc, color="red", thickness=2):
"""
Draw an arrowed line on a PIL image from a starting point to an ending point.
Args:
image (PIL.Image.Image): The input PIL image.
start_loc (tuple): Starting point (x, y) for the arrowed line.
end_loc (tuple): Ending point (x, y) for the arrowed line.
color (str): Color of the line (e.g., 'red', 'green', 'blue').
thickness (int): Thickness of the line and arrowhead.
Returns:
PIL.Image.Image: The PIL image with the drawn arrowed line.
"""
image = image.copy()
if isinstance(color, str):
color = PIL.ImageColor.getcolor(color, "RGB")
# Create a drawing context on the image
draw = ImageDraw.Draw(image)
# Draw a line from start to end
draw.line([start_loc, end_loc], fill=color, width=thickness)
# Calculate arrowhead points
arrow_size = 10 # Size of the arrowhead
dx = end_loc[0] - start_loc[0]
dy = end_loc[1] - start_loc[1]
length = (dx ** 2 + dy ** 2) ** 0.5
cos_theta = dx / length
sin_theta = dy / length
x1 = end_loc[0] - arrow_size * cos_theta
y1 = end_loc[1] - arrow_size * sin_theta
x2 = end_loc[0] - arrow_size * sin_theta
y2 = end_loc[1] + arrow_size * cos_theta
x3 = end_loc[0] + arrow_size * sin_theta
y3 = end_loc[1] - arrow_size * cos_theta
# Draw the arrowhead triangle
draw.polygon([end_loc, (x1, y1), (x2, y2), (x3, y3)], fill=color)
return image
def center_crop_to_fraction(image, frac=0.5):
"""Center crop an image to a fraction of its original size."""
width, height = image.size
new_width = int(width * frac)
new_height = int(height * frac)
left = (width - new_width) // 2
top = (height - new_height) // 2
right = (width + new_width) // 2
bottom = (height + new_height) // 2
return image.crop((left, top, right, bottom))
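# Example (sketch): keep the central half of each dimension.
#   img = Image.new("RGB", (200, 100))
#   cropped = center_crop_to_fraction(img, frac=0.5)  # -> 100x50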
def decord_load_frames(vr, frame_indices):
if isinstance(frame_indices, int):
frame_indices = [frame_indices]
frames = vr.get_batch(frame_indices).asnumpy()
frames = [Image.fromarray(frame) for frame in frames]
return frames
def paste_mask_on_image(original_image, bounding_box, mask):
"""
Paste a 2D mask onto the original image at the location specified by the bounding box.
Parameters:
- original_image (PIL.Image): The original image.
- bounding_box (tuple): Bounding box coordinates (left, top, right, bottom).
- mask (PIL.Image): The 2D mask.
Returns:
- PIL.Image: Image with the mask pasted on it.
Example:
```
original_image = Image.open('original.jpg')
bounding_box = (100, 100, 200, 200)
mask = Image.open('mask.png')
result_image = paste_mask_on_image(original_image, bounding_box, mask)
result_image.show()
```
"""
# Create a copy of the original image to avoid modifying the input image
result_image = original_image.copy()
# Crop the mask to the size of the bounding box
mask_cropped = mask.crop((0, 0, bounding_box[2] - bounding_box[0], bounding_box[3] - bounding_box[1]))
# Paste the cropped mask onto the original image at the specified location
result_image.paste(mask_cropped, (bounding_box[0], bounding_box[1]))
return result_image
def display_images_as_video_moviepy(image_list, fps=5, show=True):
"""
Display a list of PIL images as a video in Jupyter Notebook using MoviePy.
Parameters:
- image_list (list): List of PIL images.
- fps (int): Frames per second for the video.
- show (bool): Whether to display the video in the notebook.
Example:
```
image_list = [Image.open('frame1.jpg'), Image.open('frame2.jpg'), ...]
display_images_as_video_moviepy(image_list, fps=10)
```
"""
from IPython.display import display
from moviepy.editor import ImageSequenceClip
image_list = list(map(np.asarray, image_list))
clip = ImageSequenceClip(image_list, fps=fps)
    if show:
        display(clip.ipython_display(width=200))
        # ipython_display writes a temporary file into the working directory; clean it up
        os.remove("__temp__.mp4")
def resize_height(img, H):
w, h = img.size
asp_ratio = w / h
W = np.ceil(asp_ratio * H).astype(int)
return img.resize((W, H))
def resize_width(img, W):
w, h = img.size
asp_ratio = w / h
H = int(W / asp_ratio)
return img.resize((W, H))
def resized_minor_side(img, size=256):
    # PIL's Image.size is (width, height)
    W, H = img.size
    if H < W:
        H_new = size
        W_new = int(size * W / H)
    else:
        W_new = size
        H_new = int(size * H / W)
    return img.resize((W_new, H_new))
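# Example (sketch): resize so the shorter side becomes 256 pixels.
#   img = Image.new("RGB", (640, 480))
#   resized_minor_side(img, size=256).size  # -> (341, 256)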
def brighten_image(img, alpha=1.2):
enhancer = PIL.ImageEnhance.Brightness(img)
img = enhancer.enhance(alpha)
return img
def darken_image(img, alpha=0.8):
enhancer = PIL.ImageEnhance.Brightness(img)
img = enhancer.enhance(alpha)
return img
def fig2img(fig):
"""Convert a Matplotlib figure to a PIL Image and return it"""
import io
buf = io.BytesIO()
fig.savefig(buf)
buf.seek(0)
img = Image.open(buf)
return img
def show_temporal_tsne(
tsne,
timestamps=None,
title="tSNE: feature vectors over time",
cmap='viridis',
ax=None,
fig=None,
show=True,
num_ticks=10,
return_as_pil=False,
dpi=100,
label='Time (s)',
figsize=(6, 4),
s=None,
):
if timestamps is None:
timestamps = np.arange(len(tsne))
if ax is None or fig is None:
fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
cmap = plt.get_cmap(cmap)
scatter = ax.scatter(
tsne[:, 0], tsne[:, 1], c=np.arange(len(tsne)), cmap=cmap, s=s,
edgecolor='k', linewidth=0.5,
)
ax.grid(alpha=0.4)
ax.set_title(f"{title}", fontsize=11)
ax.set_xlabel("$z_{1}$")
ax.set_ylabel("$z_{2}$")
# Create a colorbar
cbar = fig.colorbar(scatter, ax=ax, label=label)
# Set custom ticks and labels on the colorbar
ticks = np.linspace(0, len(tsne) - 1, num_ticks, dtype=int)
tick_labels = np.round(timestamps[ticks], 1)
cbar.set_ticks(ticks)
cbar.set_ticklabels(tick_labels)
if show:
plt.show()
    else:
        if return_as_pil:
            plt.tight_layout(pad=0.2)
            # Return as a PIL Image without displaying the figure
            image = fig2img(fig)
            plt.close(fig)
            return image
def mark_keypoints(image, keypoints, color=(255, 255, 0), radius=1):
    """
    Marks keypoints on an image with a given color and radius.
    :param image: The input PIL image.
    :param keypoints: A list of (x, y) tuples representing the keypoints.
    :param color: The color to use for the keypoints (default: yellow).
    :param radius: The radius of the circle to draw for each keypoint (default: 1).
    :return: A new PIL image with the keypoints marked.
    """
# Make a copy of the image to avoid modifying the original
image_copy = image.copy()
# Create a draw object to add graphical elements
draw = ImageDraw.Draw(image_copy)
# Loop through each keypoint and draw a circle
for x, y in keypoints:
# Draw a circle with the specified radius and color
draw.ellipse(
(x - radius, y - radius, x + radius, y + radius),
fill=color,
width=2
)
return image_copy
def draw_line_on_image(image, x_coords, y_coords, color=(255, 255, 0), width=3):
"""
Draws a line on an image given lists of x and y coordinates.
:param image: The input PIL image.
:param x_coords: List of x-coordinates for the line.
:param y_coords: List of y-coordinates for the line.
    :param color: Color of the line in RGB (default is yellow).
:param width: Width of the line (default is 3).
:return: The PIL image with the line drawn.
"""
image = image.copy()
# Ensure the number of x and y coordinates are the same
if len(x_coords) != len(y_coords):
raise ValueError("x_coords and y_coords must have the same length")
# Create a draw object to draw on the image
draw = ImageDraw.Draw(image)
# Create a list of (x, y) coordinate tuples
coordinates = list(zip(x_coords, y_coords))
# Draw the line connecting the coordinates
draw.line(coordinates, fill=color, width=width)
return image
def add_binary_strip_vertically(
image,
binary_vector,
strip_width=15,
one_color="yellow",
zero_color="gray",
):
"""
Add a binary strip to the right side of an image.
:param image: PIL Image to which the strip will be added.
:param binary_vector: Binary vector of length 512 representing the strip.
:param strip_width: Width of the strip to be added.
:param one_color: Color for "1" pixels (default: red).
:param zero_color: Color for "0" pixels (default: white).
:return: New image with the binary strip added on the right side.
"""
one_color = PIL.ImageColor.getrgb(one_color)
zero_color = PIL.ImageColor.getrgb(zero_color)
height = image.height
if len(binary_vector) != height:
raise ValueError("Binary vector must be of length 512")
# Create a new strip with the specified width and 512 height
strip = PIL.Image.new("RGB", (strip_width, height))
# Fill the strip based on the binary vector
pixels = strip.load()
for i in range(height):
color = one_color if binary_vector[i] == 1 else zero_color
for w in range(strip_width):
pixels[w, i] = color
    # Paste the strip over the right edge of the original image
new_image = image.copy()
new_image.paste(strip, (image.width - strip_width, 0))
return new_image
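# Example (sketch): mark the top half of the rows with yellow.
#   img = Image.new("RGB", (64, 32), "black")
#   bits = [1] * 16 + [0] * 16  # one bit per image row
#   striped = add_binary_strip_vertically(img, bits, strip_width=4)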
def add_binary_strip_horizontally(
image,
binary_vector,
strip_height=15,
one_color="limegreen",
zero_color="gray",
):
"""
Add a binary strip to the top of an image.
:param image: PIL Image to which the strip will be added.
:param binary_vector: Binary vector of length 512 representing the strip.
:param strip_height: Height of the strip to be added.
:param one_color: Color for "1" pixels, accepts color names or hex (default: red).
:param zero_color: Color for "0" pixels, accepts color names or hex (default: white).
:return: New image with the binary strip added at the top.
"""
width = image.width
if len(binary_vector) != width:
raise ValueError("Binary vector must be of length 512")
# Convert colors to RGB tuples
one_color_rgb = PIL.ImageColor.getrgb(one_color)
zero_color_rgb = PIL.ImageColor.getrgb(zero_color)
    # Create a new strip spanning the full image width
strip = PIL.Image.new("RGB", (width, strip_height))
# Fill the strip based on the binary vector
pixels = strip.load()
for i in range(width):
color = one_color_rgb if binary_vector[i] == 1 else zero_color_rgb
for h in range(strip_height):
pixels[i, h] = color
    # Paste the strip over the top edge of the original image
new_image = image.copy()
new_image.paste(strip, (0, 0))
return new_image
# Define a function to increase font sizes for a specific plot
def increase_font_sizes(ax, font_scale=1.6):
for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
ax.get_xticklabels() + ax.get_yticklabels()):
item.set_fontsize(item.get_fontsize() * font_scale)
def cut_fraction_of_bbox(image, box, frac=0.7):
"""
Cuts the image such that the box occupies a fraction of the image.
"""
W, H = image.size
x1, y1, x2, y2 = box
w = x2 - x1
h = y2 - y1
new_w = int(w / frac)
new_h = int(h / frac)
x1_new = max(0, x1 - (new_w - w) // 2)
x2_new = min(W, x2 + (new_w - w) // 2)
y1_new = max(0, y1 - (new_h - h) // 2)
y2_new = min(H, y2 + (new_h - h) // 2)
return image.crop((x1_new, y1_new, x2_new, y2_new))