zca-whitening / app.py
marianna13's picture
Update app.py
855dbc0
raw
history blame
2.14 kB
import gradio as gr
import torch
import kornia as K
import cv2
import numpy as np
from torchvision import transforms
from torchvision.utils import make_grid
# Run the whitening on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def resize_images(f_names, size=(50, 50)):
    """Resize image files in place.

    Each file is read as a 3-channel colour image, resized to ``size`` and
    written back to the same path, overwriting the original file.

    Args:
        f_names: Iterable of image file paths.
        size: Target ``(width, height)`` in pixels, in the order expected by
            ``cv2.resize``. Defaults to the previously hard-coded 50x50.

    Raises:
        ValueError: If a file cannot be decoded as an image.
        OSError: If the resized image cannot be written back.
    """
    for f_name in f_names:
        img = cv2.imread(f_name, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None rather than raising;
        # without this check cv2.resize fails with an opaque assertion error.
        if img is None:
            raise ValueError(f"Could not read {f_name!r} as an image")
        resized_image = cv2.resize(img, size)
        # cv2.imwrite likewise reports failure only through its return value.
        if not cv2.imwrite(f_name, resized_image):
            raise OSError(f"Could not write resized image to {f_name!r}")
def _load_rgb_image(path, size):
    """Read the image at ``path``, resize it to ``size`` and return it as RGB.

    ``size`` is ``(width, height)`` in the order expected by ``cv2.resize``.
    The file on disk is not modified.

    Raises:
        ValueError: If ``path`` cannot be decoded as an image.
    """
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    # cv2.imread returns None (instead of raising) for unreadable files.
    if img is None:
        raise ValueError(f"Could not read {path!r} as an image")
    img = cv2.resize(img, size)
    # OpenCV decodes to BGR channel order, but the output grid is displayed
    # as RGB; without this conversion red and blue are swapped.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def predict(images, eps, *, size=(50, 50)):
    """Fit ZCA whitening on the uploaded images and return them as one grid.

    Args:
        images: Items from the Gradio ``'files'`` input. Each item is either a
            file path or a file-like object exposing ``.name`` (which form is
            passed depends on the Gradio version).
        eps: Regularisation value forwarded to ``ZCAWhitening``; converted
            with ``float``.
        size: ``(width, height)`` every image is resized to, so the images
            can be stacked into a single batch.

    Returns:
        A numpy array of shape ``(H, W, 3)``: the whitened images tiled three
        per row by ``make_grid`` and min-max normalised over the whole grid.

    Raises:
        ValueError: If fewer than two images are provided (the app requires
            at least two to estimate a covariance), or a file cannot be read
            as an image.
    """
    if not images or len(images) < 2:
        raise ValueError("Please upload at least 2 images.")
    eps = float(eps)
    paths = [getattr(img, "name", img) for img in images]

    # Resize in memory rather than overwriting the uploaded/example files on
    # disk, which also avoids a lossy JPEG re-encode before whitening.
    to_tensor = transforms.ToTensor()
    batch = torch.stack(
        [to_tensor(_load_rgb_image(path, size)) for path in paths], dim=0
    ).to(device)

    zca = K.enhance.ZCAWhitening(eps=eps)
    zca.fit(batch)
    zca_images = zca(batch)

    grid = make_grid(zca_images, nrow=3, normalize=True).cpu().numpy()
    # make_grid yields (C, H, W); the image output expects (H, W, C).
    return np.transpose(grid, (1, 2, 0))
# Text shown at the top of the Gradio interface (the description is Markdown).
title = 'ZCA Whitening with Kornia!'
description = '''[ZCA Whitening](https://paperswithcode.com/method/zca-whitening) is an image preprocessing method that transforms the data so that its covariance matrix becomes the identity matrix, which decorrelates the features.

*Note: you can only upload image files (e.g. JPG, PNG), and there must be at least 2 images. Images are resized to 50x50 pixels before whitening.*

Learn more about [ZCA Whitening and Kornia](https://kornia.readthedocs.io/en/latest/_modules/kornia/enhance/zca.html)'''
# Gradio UI: a multi-file upload plus a slider for ZCA's ``eps`` value; the
# result of ``predict`` is rendered as a single image.
iface = gr.Interface(
    fn=predict,
    inputs=[
        'files',
        # Label the slider so users can tell which parameter it controls.
        gr.Slider(0.01, 1, label="eps (regularization)"),
    ],
    outputs=gr.Image(),
    allow_flagging="never",
    title=title,
    description=description,
    # A single example: ten flower photos with eps=0.1. The paths are
    # relative, so they are resolved from the app's working directory.
    examples=[[
        [
            'irises.jpg',
            'roses.jpg',
            'sunflower.jpg',
            'violets.jpg',
            'chamomile.jpg',
            'tulips.jpg',
            'Alstroemeria.jpg',
            'Carnation.jpg',
            'Orchid.jpg',
            'Peony.jpg',
        ],
        0.1,
    ]],
)
def main():
    """Launch the Gradio app with ``show_error`` enabled."""
    iface.launch(show_error=True)


if __name__ == "__main__":
    main()