# Provenance: app.py by gagan3012, commit 42acc98 ("Update app.py"), 2.38 kB.
import gradio as gr
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import kornia as K
# Standard deviation of the Gaussian noise added to the [0, 1]-scaled input.
_NOISE_STD = 0.1
# Weight of the total-variation penalty relative to the L2 data term.
_TV_WEIGHT = 0.0001
# SGD hyper-parameters used to optimise the clean-image estimate.
_LEARNING_RATE = 0.1
_MOMENTUM = 0.9
# The loss is printed once every this many iterations.
_LOG_EVERY = 50


class TVDenoise(torch.nn.Module):
    """Total-variation denoising objective over a single learnable image.

    The module owns one parameter, ``clean_image``, initialised as a copy of
    the noisy input. Calling the module returns the quantity to minimise: the
    mean squared error between ``clean_image`` and the noisy image plus
    ``_TV_WEIGHT`` times the total-variation term of ``clean_image``.
    """

    def __init__(self, noisy_image):
        super().__init__()
        self.l2_term = torch.nn.MSELoss(reduction='mean')
        self.regularization_term = K.losses.TotalVariation()
        # The variable which is optimised to produce the noise-free image.
        self.clean_image = torch.nn.Parameter(data=noisy_image.clone(), requires_grad=True)
        self.noisy_image = noisy_image

    def forward(self):
        """Return the data-fidelity term plus the weighted TV penalty."""
        data_term = self.l2_term(self.clean_image, self.noisy_image)
        tv_term = self.regularization_term(self.clean_image)
        return data_term + _TV_WEIGHT * tv_term

    def get_clean_image(self):
        """Return the current estimate of the denoised image (a Parameter)."""
        return self.clean_image


def inference(file1, num_iters):
    """Add Gaussian noise to an image, then remove it with TV denoising.

    Args:
        file1: Object whose ``.name`` attribute is the path of the image file
            to load.
        num_iters: Number of SGD iterations to run; converted with ``int()``.

    Returns:
        A ``(noisy, denoised)`` tuple of NumPy images, both clipped to the
        range [0, 1].

    Raises:
        ValueError: If no file was supplied or OpenCV could not read an image
            from the given path.
    """
    if file1 is None:
        raise ValueError("No input image was provided.")

    # cv2.imread signals failure by returning None instead of raising.
    img = cv2.imread(file1.name)
    if img is None:
        raise ValueError("Could not read an image from {!r}.".format(file1.name))

    # OpenCV loads BGR; switch to RGB and rescale to floats in [0, 1].
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
    noisy: np.ndarray = img + np.random.normal(loc=0.0, scale=_NOISE_STD, size=img.shape)
    noisy = np.clip(noisy, 0.0, 1.0)

    # Convert to a torch tensor (expected to be CxHxW per kornia's
    # image_to_tensor -- confirm against the installed kornia version).
    noisy_image = K.utils.image_to_tensor(noisy).squeeze()

    tv_denoiser = TVDenoise(noisy_image)
    # The optimizer updates the single parameter of tv_denoiser.
    optimizer = torch.optim.SGD(
        tv_denoiser.parameters(), lr=_LEARNING_RATE, momentum=_MOMENTUM
    )

    total_iters = int(num_iters)
    for i in range(total_iters):
        optimizer.zero_grad()
        # torch.mean collapses the loss to a scalar before backpropagation.
        loss = torch.mean(tv_denoiser())
        if i % _LOG_EVERY == 0:
            print(f"Loss in iteration {i} of {total_iters}: {loss.item():.3f}")
        loss.backward()
        optimizer.step()

    img_clean: np.ndarray = K.utils.tensor_to_image(tv_denoiser.get_clean_image())
    # The optimisation is unconstrained, so pixels may drift slightly outside
    # [0, 1]; clip so the result is a valid image like the noisy one.
    img_clean = np.clip(img_clean, 0.0, 1.0)
    return noisy, img_clean
# Example rows offered in the UI: [image path, num_iters].
examples = [
    ["doraemon.png", 2000],
]

# Inputs are passed to `inference` positionally: the image, then num_iters.
inputs = [
    gr.Image(type='file', label='Input Image'),
    gr.Slider(minimum=50, maximum=10000, step=50, default=500, label="num_iters"),
]

# NOTE(review): `inference` returns NumPy arrays; confirm that the installed
# Gradio version accepts arrays for outputs declared with type='file'.
outputs = [
    gr.Image(type='file', label='Noised Image'),
    gr.Image(type='file', label='Denoised Image'),
]

# This demo performs total-variation denoising (see `inference`); the previous
# title described an unrelated image-stitching demo.
title = "Total Variation Denoising using Kornia"

demo_app = gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    examples=examples,
    theme='huggingface',
)
demo_app.launch(debug=True)