# app.py — Hugging Face Space by gagan3012
# Commit a29a059: "Create app.py" (2.36 kB; page scan: no virus)
import gradio as gr
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import kornia as K
class _TVDenoise(torch.nn.Module):
    """Total-variation denoising objective whose only parameter is the clean image.

    Minimizing ``forward()`` pulls ``clean_image`` towards ``noisy_image``
    (mean squared error). At the same time it penalizes the total variation of
    ``clean_image``, weighted by ``tv_weight``.
    """

    def __init__(self, noisy_image: torch.Tensor, tv_weight: float = 0.0001) -> None:
        super().__init__()
        self.l2_term = torch.nn.MSELoss(reduction="mean")
        self.regularization_term = K.losses.TotalVariation()
        self.tv_weight = tv_weight
        # The variable being optimized; it starts as a copy of the noisy observation.
        self.clean_image = torch.nn.Parameter(data=noisy_image.clone(), requires_grad=True)
        self.noisy_image = noisy_image

    def forward(self) -> torch.Tensor:
        """Return the data-fidelity term plus the weighted TV penalty."""
        fidelity = self.l2_term(self.clean_image, self.noisy_image)
        smoothness = self.regularization_term(self.clean_image)
        return fidelity + self.tv_weight * smoothness

    def get_clean_image(self) -> torch.Tensor:
        """Return the current estimate of the denoised image (still attached to autograd)."""
        return self.clean_image


def inference(file1, num_iters, noise_sigma=0.1, tv_weight=0.0001, lr=0.1, momentum=0.9):
    """Add Gaussian noise to an image, then denoise it by total-variation minimization.

    Args:
        file1: The uploaded image. This is either a path (``str`` / ``os.PathLike``,
            as produced by ``gr.Image(type="filepath")``) or a file-like object
            exposing ``.name`` (as produced by the legacy ``type="file"``).
        num_iters: Number of SGD steps. It is passed through ``int()`` because
            a slider may deliver a float.
        noise_sigma: Standard deviation of the synthetic Gaussian noise, on the
            [0, 1] intensity scale.
        tv_weight: Weight of the total-variation term relative to the MSE term.
        lr: SGD learning rate.
        momentum: SGD momentum.

    Returns:
        ``(noisy, clean)``: two H x W x 3 RGB float arrays with values in [0, 1].

    Raises:
        ValueError: If no image was provided or OpenCV cannot decode the file.
    """
    import os

    if file1 is None:
        raise ValueError("No input image was provided.")
    # Path objects also have a `.name` (the basename), so check for paths first.
    if isinstance(file1, (str, os.PathLike)):
        path = os.fspath(file1)
    else:
        path = file1.name

    bgr = cv2.imread(path)
    # cv2.imread signals failure by returning None rather than raising.
    if bgr is None:
        raise ValueError(f"Could not read an image from {path!r}.")

    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) / 255.0
    img = img + np.random.normal(loc=0.0, scale=noise_sigma, size=img.shape)
    img = np.clip(img, 0.0, 1.0)

    # Convert to a float32 tensor. Float64 only slows the optimization loop down.
    noisy_image = K.utils.image_to_tensor(img).float()
    # NOTE(review): depending on the kornia version, image_to_tensor may add a
    # leading batch dim. Drop only that dim; `.squeeze()` would also collapse
    # a spatial dim of size 1.
    if noisy_image.dim() == 4:
        noisy_image = noisy_image[0]

    tv_denoiser = _TVDenoise(noisy_image, tv_weight=tv_weight)
    # The optimizer updates the single parameter of tv_denoiser: the clean image.
    optimizer = torch.optim.SGD(tv_denoiser.parameters(), lr=lr, momentum=momentum)
    for i in range(int(num_iters)):
        optimizer.zero_grad()
        # Reduce to a scalar in case the TV term is not already one.
        loss = torch.mean(tv_denoiser())
        if i % 50 == 0:
            print("Loss in iteration {} of {}: {:.3f}".format(i, num_iters, loss.item()))
        loss.backward()
        optimizer.step()

    clean = tv_denoiser.get_clean_image().detach()
    # The optimized estimate is unconstrained and may leave [0, 1]. Clip it so
    # image display code (presumably expecting floats in [0, 1]) accepts it.
    img_clean: np.ndarray = np.clip(K.utils.tensor_to_image(clean), 0.0, 1.0)
    return img, img_clean
# No bundled examples. Each row would be [image_path, num_iters], matching `inputs`.
examples = []

inputs = [
    # "filepath" hands `inference` a plain path string. The legacy "file" type
    # is rejected by newer Gradio releases.
    gr.Image(type="filepath", label="Input Image"),
    # `value=` replaces the deprecated `default=` keyword.
    gr.Slider(minimum=50, maximum=10000, step=50, value=500, label="num_iters"),
]

# Outputs use the default (numpy) image type, because `inference` returns arrays.
outputs = [
    gr.Image(label="Noised Image"),
    gr.Image(label="Denoised Image"),
]

# The previous title ("Image Stitching using Kornia and LoFTR") described a different app.
title = "Image Denoising using Kornia and Total Variation"

demo_app = gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    examples=examples,
    # NOTE(review): "huggingface" is a legacy Gradio 3.x theme name. Newer
    # releases may warn and fall back to the default theme; confirm the target version.
    theme="huggingface",
)

if __name__ == "__main__":
    demo_app.launch(debug=True)