|
import os |
|
import onnxruntime |
|
import gradio as gr |
|
import numpy as np |
|
from PIL import Image |
|
|
|
|
|
# Path to the exported SAR-to-RGB model, resolved relative to the current
# working directory.
onnx_model_path = "sar2rgb.onnx"

# Since onnxruntime 1.9, builds that ship more than one execution provider
# (e.g. onnxruntime-gpu) raise ValueError unless ``providers`` is passed
# explicitly. get_available_providers() returns the installed providers in
# ORT's default priority order, so this also works unchanged on CPU-only builds.
sess = onnxruntime.InferenceSession(
    onnx_model_path,
    providers=onnxruntime.get_available_providers(),
)
|
|
|
|
|
def predict(input_image):
    """Translate a SAR image into an RGB image with the ONNX model ``sess``.

    The input is converted to RGB and resized to 256x256. It is then scaled
    from [0, 255] to [-1, 1] and fed to the model as a (1, C, H, W) float32
    batch. The model output is mapped back from [-1, 1] to [0, 255].

    Args:
        input_image: PIL image from the Gradio ``Image(type="pil")`` input.
            Any PIL mode is accepted; it is converted to RGB first.

    Returns:
        PIL.Image.Image built from the model's first output tensor
        (presumably 256x256 RGB — depends on the exported model).

    Raises:
        gr.Error: if no image was supplied.
    """
    if input_image is None:
        # Gradio passes None when the user submits without uploading an image.
        raise gr.Error("Please upload an image.")

    # Grayscale ("L") or RGBA images would otherwise produce a 2-D or
    # 4-channel array and break the transpose or the model's channel count.
    # NOTE(review): assumes the model expects 3 input channels — confirm.
    image = input_image.convert("RGB").resize((256, 256))

    # HWC uint8 -> CHW float32 in [-1, 1], then add a leading batch axis.
    array = np.asarray(image, dtype=np.float32).transpose(2, 0, 1) / 255.0
    array = (array - 0.5) / 0.5
    batch = np.expand_dims(array, axis=0)

    inputs = {sess.get_inputs()[0].name: batch}
    output = sess.run(None, inputs)

    # Drop only the batch axis. squeeze() would also remove a size-1 channel
    # axis and make the transpose below fail.
    result = output[0]
    if result.ndim == 4:
        result = result[0]
    result = result.transpose(1, 2, 0)

    result = (result + 1) / 2
    # Clip before casting: casting floats outside [0, 255] to uint8 is
    # undefined in NumPy (it typically wraps, e.g. 256 -> 0), which shows up
    # as speckle artifacts in bright or dark regions.
    result = np.clip(result * 255, 0, 255).astype(np.uint8)

    if result.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) array; use (H, W).
        result = result[..., 0]

    return Image.fromarray(result)
|
|
|
|
|
# File extensions accepted as example inputs (compared case-insensitively).
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tif", ".tiff", ".webp")


def _list_example_images(directory="examples"):
    """Return Gradio example rows for the image files in ``directory``.

    Each row is a one-element list because the interface has a single input.
    Names are sorted so the example order is stable across platforms
    (``os.listdir`` order is arbitrary). Hidden files and non-image files
    (e.g. ``README.md``) are skipped. A missing directory yields an empty
    list instead of crashing the app at startup.
    """
    if not os.path.isdir(directory):
        return []
    return [
        [os.path.join(directory, fname)]
        for fname in sorted(os.listdir(directory))
        if not fname.startswith(".")
        and fname.lower().endswith(_IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(directory, fname))
    ]


example_images = _list_example_images()

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    # Pass None (Gradio's default, meaning "no examples") rather than an
    # empty list when no example files were found.
    examples=example_images or None,
)

iface.launch()
|
|