|
import os |
|
from typing import Optional |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torch |
|
from lpips import LPIPS |
|
from PIL import Image |
|
from torchvision.transforms import Normalize |
|
|
|
|
|
def show_images_horizontally(
    list_of_files: np.array, output_file: Optional[str] = None, interact: bool = False
) -> None:
    """
    Plot a sequence of images side by side in a single row.

    The figure is either displayed interactively or written to ``output_file``.

    Args:
        list_of_files: The images to draw, as a numpy array with shape
            (N, H, W, C) or a sequence of N arrays each shaped (H, W, C).
        output_file: Path the figure is saved to when ``interact`` is False.
        interact: If True, display the figure with ``plt.show()`` instead of
            saving it.

    Raises:
        ValueError: If no images are given, or if ``interact`` is False and
            ``output_file`` is None.
    """
    number_of_files = len(list_of_files)
    if number_of_files == 0:
        raise ValueError("At least one image is required.")
    if not interact and output_file is None:
        raise ValueError("output_file must be provided when interact is False.")

    # Axis 0 of every image is its height and axis 1 is its width.
    heights = [image.shape[0] for image in list_of_files]
    widths = [image.shape[1] for image in list_of_files]

    # Each image gets `fig_width` inches; the row height follows the images'
    # overall height-to-width ratio.
    fig_width = 8.0
    fig_height = fig_width * sum(heights) / sum(widths)

    # squeeze=False keeps `axs` two-dimensional, so a single image works too.
    fig, axs = plt.subplots(
        1,
        number_of_files,
        figsize=(fig_width * number_of_files, fig_height),
        squeeze=False,
    )
    plt.tight_layout()
    for ax, image in zip(axs[0], list_of_files):
        ax.imshow(image)
        ax.axis("off")

    if interact:
        plt.show()
    else:
        plt.savefig(output_file, bbox_inches="tight", pad_inches=0.25)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
|
|
|
|
|
def save_image(image: np.array, file_name: str) -> None:
    """
    Write an image array to disk through PIL.

    Args:
        image: The image to store, as a numpy array with shape (H, W, C).
        file_name: Destination path of the image file (e.g. a ``.jpg`` path).
    """
    Image.fromarray(image).save(file_name)
|
|
|
|
|
def load_and_process_images(load_dir: str) -> list:
    """
    Load the JPG images found in a directory.

    Only entries whose name ends with ".jpg" are read; every other entry is
    ignored. The images are ordered by the integer value of the part of the
    file name before the first dot, so "2.jpg" comes before "10.jpg". The
    directory path is printed to stdout.

    Args:
        load_dir: The directory to load the images from.

    Returns:
        images: A list with one numpy array per image, holding the pixel
            values divided by 255.0 (presumably (H, W, C) arrays in [0, 1]
            for 8-bit colour images).

    Raises:
        ValueError: If the name of a ".jpg" file does not start with an
            integer followed by a dot.
    """
    print(load_dir)

    # Filter before sorting so that unrelated entries (hidden files, notes,
    # sub-directories, ...) cannot break the numeric sort key.
    filenames = sorted(
        (name for name in os.listdir(load_dir) if name.endswith(".jpg")),
        key=lambda name: int(name.split(".")[0]),
    )

    images = []
    for filename in filenames:
        # The context manager closes the file once the pixels have been read.
        with Image.open(os.path.join(load_dir, filename)) as img:
            images.append(np.asarray(img) / 255.0)
    return images
|
|
|
|
|
def compute_lpips(images: np.array, lpips_model: LPIPS) -> np.array:
    """
    Compute the LPIPS distance between each pair of consecutive images.

    Args:
        images: The input images with shape (N, H, W, C), given as a numpy
            array or as a sequence of N arrays shaped (H, W, C). Three
            channels are required by the normalization statistics below.
        lpips_model: The LPIPS model used to compute perceptual distances.

    Returns:
        distances: A numpy array with N - 1 entries, where entry ``i`` is the
            distance between image ``i`` and image ``i + 1``. The array is
            empty when fewer than two images are given.
    """
    if len(images) < 2:
        # No consecutive pair exists, so there is nothing to measure.
        return np.array([])

    # Run on whichever device holds the model's weights.
    device = next(lpips_model.parameters()).device

    if torch.is_tensor(images):
        batch = images.detach().to(device)
    else:
        # Stack into a single array first: building a tensor directly from a
        # list of arrays is extremely slow.
        batch = torch.tensor(np.asarray(images), device=device)

    # (N, H, W, C) -> (N, C, H, W)
    batch = torch.permute(batch.float(), (0, 3, 1, 2))

    # NOTE(review): these look like the ImageNet channel statistics, while the
    # LPIPS documentation describes inputs scaled to [-1, 1] - confirm that
    # this normalization is intended before changing it (it affects scores).
    normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    batch = normalize(batch)

    distances = []
    # Inference only: skipping autograd bookkeeping saves memory and time
    # without changing the computed values.
    with torch.no_grad():
        for i in range(batch.shape[0] - 1):
            loss = lpips_model(batch[i : i + 1], batch[i + 1 : i + 2])
            distances.append(loss.item())
    return np.array(distances)
|
|
|
|
|
def compute_gini(distances: np.array) -> float:
    """
    Compute the Gini index of the input distances.

    The index is the mean absolute difference over all ordered pairs of
    values divided by twice the mean of the values. It is 0 when all values
    are equal and grows as the values become more unevenly distributed.

    Args:
        distances: The input distances as a 1-D numpy array (or sequence).

    Returns:
        gini: The Gini index of the input distances. 0.0 is returned when
            fewer than two values are given, or when the values sum to zero
            (e.g. all distances are zero), where the ratio is undefined.
    """
    if len(distances) < 2:
        return 0.0

    values = np.sort(np.asarray(distances, dtype=np.float64))
    n = len(values)
    total = values.sum()
    if total == 0:
        # The ratio would be 0 / 0; identical (all-zero) distances are
        # treated as perfectly even.
        return 0.0

    # For ascending values x_1..x_n:
    #   sum_{i,j} |x_i - x_j| = 2 * sum_i (2i - n - 1) * x_i
    # so the pairwise definition sum_{i,j}|x_i - x_j| / (2 * n^2 * mean)
    # reduces to the expression below without an O(n^2) double loop.
    ranks = np.arange(1, n + 1)
    weighted_sum = np.dot(2 * ranks - n - 1, values)
    return float(weighted_sum / (n * total))
|
|
|
|
|
def compute_smoothness_and_consistency(images: np.array, lpips_model: LPIPS) -> tuple:
    """
    Summarize how evenly and how closely an image sequence progresses.

    The metrics are all derived from the LPIPS distances between consecutive
    images of the sequence.

    Args:
        images: The input images as numpy array with shape (N, H, W, C).
        lpips_model: The LPIPS model used to compute perceptual distances.

    Returns:
        A 3-tuple ``(smoothness, consistency, max_inception_distance)``:

        smoothness: One minus the Gini index of the consecutive distances.
        consistency: The mean of the consecutive distances.
        max_inception_distance: The largest of the consecutive distances.
    """
    pairwise = compute_lpips(images, lpips_model)
    return (
        1 - compute_gini(pairwise),
        np.mean(pairwise),
        np.max(pairwise),
    )
|
|
|
|
|
def separate_source_and_interpolated_images(images: np.array) -> tuple:
    """
    Split an image sequence into its two endpoints and the frames in between.

    The first and last images are treated as the sources; everything between
    them is treated as interpolated.

    Args:
        images: The input images as numpy array with shape (N, H, W, C).

    Returns:
        source: The first and last image stacked into a numpy array with
            shape (2, H, W, C).
        interpolation: The slice ``images[1:-1]``, i.e. the N - 2 images
            between the two sources.

    Raises:
        ValueError: If fewer than two images are given.
    """
    count = len(images)
    if count < 2:
        raise ValueError("The input array should have at least two elements.")

    first, last = images[0], images[-1]
    endpoints = np.array([first, last])
    return endpoints, images[1:-1]
|
|