import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn.init import _calculate_fan_in_and_fan_out
from timm.models.layers import to_2tuple, trunc_normal_
import torchvision.transforms as transforms
from torchvision import models
import gradio as gr
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt

from model import dehazeformer_t

# Pick the inference device: CUDA if available, then Apple MPS, then CPU.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

# Build the model on `device` and load the trained weights. The checkpoint is
# mapped to CPU first; load_state_dict copies the tensors into the model's
# parameters, which already live on `device`.
t_model_load = dehazeformer_t().to(device)
best_model_weights = torch.load('best_t_model_weights.pth', map_location=torch.device('cpu'))
t_model_load.load_state_dict(best_model_weights)
# This script only runs inference, so switch to eval mode once at load time.
t_model_load.eval()

# File the rendered prediction is written to and read back from on each call.
OUTPUT_PATH = 'image.png'


def pred_one_image(inp):
    """Run the model on one PIL image and return the result rendered by matplotlib.

    The input is resized to 256x256, converted to RGB and scaled to [0, 1],
    then passed to the model as a (1, 3, 256, 256) float tensor. The first
    item of the model output is drawn with ``plt.imshow`` on a 10x10-inch
    figure with axes hidden, saved as PNG to ``OUTPUT_PATH`` at 300 dpi, and
    read back as a PIL image.

    Args:
        inp: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the input is empty.

    Returns:
        The rendered PIL image, or ``None`` if ``inp`` is ``None``.
    """
    if inp is None:
        # Nothing to predict on (e.g. the Gradio image input was cleared).
        return None

    # HWC uint8 -> HWC float in [0, 1].
    one_image = np.array(inp.resize((256, 256)).convert("RGB")) / 255
    # HWC -> CHW, then add a batch dimension: (1, 3, 256, 256).
    one_image = np.moveaxis(one_image, -1, 0)
    one_image = torch.tensor(one_image).float().unsqueeze(0).to(device)

    with torch.no_grad():
        output = t_model_load(one_image)
    print(output.shape)

    # First batch item, CHW -> HWC for imshow.
    # NOTE(review): assumes the model returns (N, 3, H, W) — confirm.
    output = output[0].cpu().permute((1, 2, 0)).numpy()
    # imshow clips float RGB data to [0, 1] anyway (and warns); clip here so
    # the rendering is identical without the warning.
    output = np.clip(output, 0.0, 1.0)

    fig = plt.figure(figsize=(10, 10))
    try:
        plt.imshow(output)
        plt.axis("off")
        # Save the figure as PNG; dpi sets the output resolution.
        fig.savefig(OUTPUT_PATH, format='png', dpi=300)
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each request leaks a figure.
        plt.close(fig)

    out_img = Image.open(OUTPUT_PATH)
    # Image.open is lazy: force the pixel data into memory now so the file
    # handle is released and a later call overwriting OUTPUT_PATH cannot
    # change the image returned here.
    out_img.load()
    return out_img


demo = gr.Interface(
    fn=pred_one_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    examples=['noisy_10961455225_0786d3edd2_c.jpg'],
)
demo.launch(debug=True)