Spaces:
Runtime error
Runtime error
import random | |
from typing import Any, Optional | |
import numpy as np | |
import os | |
import cv2 | |
from glob import glob | |
from PIL import Image, ImageDraw | |
from tqdm import tqdm | |
import kornia | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import albumentations as albu | |
import functools | |
import math | |
import torch | |
import torch.nn as nn | |
from torch import Tensor | |
import torchvision as tv | |
import torchvision.models as models | |
from torchvision import transforms | |
from torchvision.transforms import functional as F | |
from losses import TempCombLoss | |
# Checkpoint filename -> Google Drive direct-download URL. Used by
# ``ensure_checkpoint_exists`` to fetch model weights on demand.
google_drive_paths = {
    "BayesCap_SRGAN.pth": "https://drive.google.com/uc?id=1d_5j1f8-vN79htZTfRUqP1ddHZIYsNvL",
    "BayesCap_ckpt.pth": "https://drive.google.com/uc?id=1Vg1r6gKgQ1J3M51n6BeKXYS8auT9NhA9",
}


def ensure_checkpoint_exists(model_weights_filename):
    """Make sure a checkpoint file is present locally, downloading it if possible.

    If ``model_weights_filename`` does not exist and is exactly one of the keys of
    ``google_drive_paths``, the file is fetched from Google Drive with ``gdown``.
    This is best-effort: problems are reported with ``print`` rather than raised.
    Errors raised by the download call itself propagate unchanged.

    Args:
        model_weights_filename (str): Path of the checkpoint file. Only an exact
            match against a key of ``google_drive_paths`` triggers a download.

    Returns:
        bool: ``True`` if the file exists when the function returns, else ``False``.
    """
    if os.path.isfile(model_weights_filename):
        return True

    gdrive_url = google_drive_paths.get(model_weights_filename)
    if gdrive_url is None:
        print(
            model_weights_filename,
            " not found, you may need to manually download the model weights."
        )
        return False

    # Guard only the import. A ModuleNotFoundError raised inside gdown while
    # downloading must not be misreported as "gdown is not installed".
    try:
        from gdown import download as drive_download
    except ModuleNotFoundError:
        print(
            "gdown module not found.",
            "pip3 install gdown or, manually download the checkpoint file:",
            gdrive_url
        )
        return False

    drive_download(gdrive_url, model_weights_filename, quiet=False)

    # The download call may return without having written the file; report
    # that instead of silently continuing.
    if not os.path.isfile(model_weights_filename):
        print(
            "Download did not produce",
            model_weights_filename,
            "- you may need to manually download it from:",
            gdrive_url
        )
        return False
    return True
def normalize(image: np.ndarray) -> np.ndarray:
    """Scale 8-bit image data into the unit interval.

    Args:
        image (np.ndarray): Pixel data as read by ``OpenCV.imread`` or
            ``skimage.io.imread`` (assumed to lie in [0, 255]).

    Returns:
        np.ndarray: A new float64 array equal to ``image / 255``. The input
        array is not modified.
    """
    # ``astype`` always returns a fresh copy, so dividing in place is safe.
    scaled = image.astype(np.float64)
    scaled /= 255.0
    return scaled
def unnormalize(image: np.ndarray) -> np.ndarray:
    """Map unit-interval image data back to the [0, 255] scale.

    This is the inverse of ``normalize``: values are multiplied by 255. The
    result stays float64 and is neither rounded nor clipped.

    Args:
        image (np.ndarray): Normalized image data (assumed to lie in [0, 1]).

    Returns:
        np.ndarray: A new float64 array equal to ``image * 255``. The input
        array is not modified.
    """
    # ``astype`` copies, so the in-place multiply never touches ``image``.
    rescaled = image.astype(np.float64)
    rescaled *= 255.0
    return rescaled
def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
    """Convert an image (``PIL.Image`` or ``np.ndarray``) to a tensor.

    The image is passed straight to ``torchvision.transforms.functional.to_tensor``.
    The caller's image object is never modified.

    Args:
        image (np.ndarray): Image data, e.g. read by ``PIL.Image``.
        range_norm (bool): Scale [0, 1] data to between [-1, 1].
        half (bool): Whether to convert the result to ``torch.half``.

    Returns:
        torch.Tensor: The converted image tensor.

    Examples:
        >>> image = Image.open("image.bmp")
        >>> tensor_image = image2tensor(image, range_norm=False, half=False)
    """
    tensor = F.to_tensor(image)
    if range_norm:
        # Out-of-place on purpose. For some ndarray inputs (e.g. single-channel,
        # non-uint8 arrays), ``to_tensor`` can return a tensor that shares memory
        # with ``image``, so ``mul_``/``sub_`` would overwrite the caller's array.
        tensor = tensor * 2.0 - 1.0
    if half:
        tensor = tensor.half()
    return tensor
def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
    """Convert an image tensor to a ``uint8`` (H, W, C) numpy array usable by PIL.

    The input tensor is left untouched: it is not re-scaled, re-shaped, or
    detached in place, and tensors that require grad are accepted.

    Args:
        tensor (torch.Tensor): Image of shape (1, C, H, W) or (C, H, W). Only a
            leading size-1 batch dimension is squeezed.
        range_norm (bool): Scale [-1, 1] data to between [0, 1].
        half (bool): Whether to convert to ``torch.half`` before scaling.

    Returns:
        np.ndarray: ``uint8`` array of shape (H, W, C). Values are multiplied
        by 255, clamped to [0, 255], then truncated (not rounded).

    Examples:
        >>> tensor = torch.randn([1, 3, 128, 128])
        >>> image = tensor2image(tensor, range_norm=False, half=False)
    """
    # detach() drops autograd history so .numpy() works on grad-tracking
    # tensors. All ops below are out-of-place: permute() returns a view, so
    # in-place ops here would rewrite the caller's tensor.
    image = tensor.detach()
    if range_norm:
        image = (image + 1.0) / 2.0
    if half:
        image = image.half()
    image = image.squeeze(0).permute(1, 2, 0) * 255
    return image.clamp(0, 255).cpu().numpy().astype("uint8")