import gradio as gr
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import kornia as K


class TVDenoise(torch.nn.Module):
    """Total-variation denoising objective whose only parameter is the clean image.

    The objective is ``MSE(clean_image, noisy_image) + tv_weight * TotalVariation(clean_image)``.
    Minimising it with respect to ``clean_image`` (initialised from the noisy
    observation) trades fidelity to the observation against smoothness.

    Args:
        noisy_image: image tensor to denoise (as produced by ``K.utils.image_to_tensor``).
        tv_weight: weight of the total-variation regulariser (default ``0.0001``).
    """

    def __init__(self, noisy_image, tv_weight=0.0001):
        super().__init__()
        self.l2_term = torch.nn.MSELoss(reduction='mean')
        self.regularization_term = K.losses.TotalVariation()
        self.tv_weight = tv_weight
        # The variable that is optimised to produce the noise-free image.
        self.clean_image = torch.nn.Parameter(data=noisy_image.clone(), requires_grad=True)
        self.noisy_image = noisy_image

    def forward(self):
        """Return the denoising loss for the current ``clean_image``."""
        return (self.l2_term(self.clean_image, self.noisy_image)
                + self.tv_weight * self.regularization_term(self.clean_image))

    def get_clean_image(self):
        """Return the (trainable) clean-image parameter."""
        return self.clean_image


def _load_noisy_image(path, noise_std=0.1):
    """Read an image from ``path`` as RGB in [0, 1] and add clipped Gaussian noise.

    Args:
        path: filesystem path of the image to read with ``cv2.imread``.
        noise_std: standard deviation of the additive Gaussian noise.

    Returns:
        ``np.ndarray`` of shape (H, W, 3), values clipped to [0, 1].

    Raises:
        ValueError: if OpenCV cannot read the file (``cv2.imread`` returns ``None``).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise ValueError("Could not read image file: {}".format(path))
    # cv2.imread (default flags) yields 8-bit BGR; convert to RGB floats in [0, 1].
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) / 255.0
    img = img + np.random.normal(loc=0.0, scale=noise_std, size=img.shape)
    return np.clip(img, 0.0, 1.0)


def inference(file1, num_iters):
    """Add noise to an uploaded image, then remove it by total-variation optimisation.

    Args:
        file1: uploaded file object from the Gradio image input; its ``.name``
            attribute is used as the path to read.
        num_iters: number of SGD iterations (cast to ``int``).

    Returns:
        Tuple ``(noisy_image, denoised_image)`` of (H, W, 3) arrays in [0, 1].

    Raises:
        ValueError: if no file was uploaded or the file cannot be read as an image.
    """
    if file1 is None:
        raise ValueError("Please upload an input image.")
    img = _load_noisy_image(file1.name)

    # convert to torch tensor
    noisy_image = K.utils.image_to_tensor(img).squeeze()

    tv_denoiser = TVDenoise(noisy_image)
    # define the optimizer to optimize the 1 parameter of tv_denoiser
    optimizer = torch.optim.SGD(tv_denoiser.parameters(), lr=0.1, momentum=0.9)

    for i in range(int(num_iters)):
        optimizer.zero_grad()
        # The TV term may not be a scalar (e.g. per-channel values); the mean
        # reduces the loss to a scalar so backward() is well-defined.
        loss = torch.mean(tv_denoiser())
        if i % 50 == 0:
            print("Loss in iteration {} of {}: {:.3f}".format(i, num_iters, loss.item()))
        loss.backward()
        optimizer.step()

    # Detach from the autograd graph and clamp: the optimised values can drift
    # slightly outside [0, 1], which float-image encoders reject.
    clean = tv_denoiser.get_clean_image().detach()
    img_clean: np.ndarray = np.clip(K.utils.tensor_to_image(clean), 0.0, 1.0)
    return img, img_clean


examples = []

# NOTE(review): type='file', default= and theme='huggingface' follow an older
# Gradio API; confirm they match the pinned gradio version before upgrading.
inputs = [
    gr.Image(type='file', label='Input Image'),
    gr.Slider(minimum=50, maximum=10000, step=50, default=500, label="num_iters"),
]

outputs = [
    gr.Image(type='file', label='Noised Image'),
    gr.Image(type='file', label='Denoised Image'),
]

title = "Image Denoising using Kornia Total Variation"

demo_app = gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    examples=examples,
    theme='huggingface',
)
demo_app.launch(debug=True)