Spaces:
Sleeping
Sleeping
import gradio as gr | |
import netron | |
import os | |
import threading | |
import time | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
from yolov5 import xai_yolov5 | |
from yolov8 import xai_yolov8s | |
import requests | |
# NOTE(review): The triple-quoted string below is an earlier Blocks-based UI
# (sample-image picker, YOLOv5/YOLOv8s XAI result gallery, embedded Netron
# iframe). It has been switched off by turning it into a bare, unused string
# literal, so none of it executes. The imports at the top of the file (netron,
# threading, time, PIL, cv2, numpy, yolov5, yolov8, requests) appear to be used
# only by this disabled code, yet they are still executed on import.
# Caveat: the `custom_css = """` line inside it actually closes this string
# early. The `#run_button` lines after it are ordinary (harmless) comments, and
# the following `"""` opens a second string that ends after
# `interface.launch(share=True)`. Prefer deleting this block (version control
# keeps it) or commenting it out with `#`.
"""
# Sample images directory
sample_images = {
    "Sample 1": os.path.join(os.getcwd(), "data/xai/sample1.jpeg"),
    "Sample 2": os.path.join(os.getcwd(), "data/xai/sample2.jpg"),
}
# Preloaded model file path
preloaded_model_file = os.path.join(os.getcwd(), "weight_files/yolov5.onnx") # Example path
# Function to load sample image
def load_sample_image(sample_name):
    image_path = sample_images.get(sample_name)
    if image_path and os.path.exists(image_path):
        return Image.open(image_path)
    return None
# Function to process the image
def process_image(sample_choice, uploaded_image, yolo_versions):
    # Use uploaded or sample image
    if uploaded_image is not None:
        image = uploaded_image
    else:
        image = load_sample_image(sample_choice)
    # Resize and process the image
    image = np.array(image)
    image = cv2.resize(image, (640, 640))
    result_images = []
    for yolo_version in yolo_versions:
        if yolo_version == "yolov5":
            result_images.append(xai_yolov5(image))
        elif yolo_version == "yolov8s":
            result_images.append(xai_yolov8s(image))
        else:
            result_images.append((Image.fromarray(image), f"{yolo_version} not implemented."))
    return result_images
# Start Netron backend
def start_netron_backend(model_file):
    def serve_netron():
        netron.start(model_file, address=("0.0.0.0", 8080), browse=False)
        #netron.start(model_file, address="0.0.0.0:8080", browse=False) # Updated Netron arguments
    # Launch Netron in a separate thread
    threading.Thread(target=serve_netron, daemon=True).start()
    # Wait until Netron server is ready
    def wait_for_netron(url, timeout=10):
        start_time = time.time()
        while time.time() - start_time < timeout:
            try:
                response = requests.get(url)
                if response.status_code == 200:
                    return True
            except requests.ConnectionError:
                time.sleep(0.5)
        return False
    # Check server readiness
    wait_for_netron("http://localhost:8080/", timeout=15)
# View Netron model
def view_netron_model():
    # Ensure model exists
    if not os.path.exists(preloaded_model_file):
        return "Model file not found."
    # Start Netron backend
    start_netron_backend(preloaded_model_file)
    return gr.HTML('<iframe src="http://localhost:8080/" width="100%" height="600px"></iframe>')
# Custom CSS for styling (optional)
custom_css = """
#run_button {
# background-color: purple;
# color: white;
# width: 120px;
# border-radius: 5px;
# font-size: 14px;
#}
"""
# Gradio UI
with gr.Blocks(css=custom_css) as interface:
    gr.Markdown("# XAI: Visualize Object Detection of Your Models")
    # Default sample
    default_sample = "Sample 1"
    with gr.Row():
        # Left: Select sample or upload image
        with gr.Column():
            sample_selection = gr.Radio(
                choices=list(sample_images.keys()),
                label="Select a Sample Image",
                type="value",
                value=default_sample,
            )
            upload_image = gr.Image(label="Upload an Image", type="pil")
            selected_models = gr.CheckboxGroup(
                choices=["yolov5", "yolov8s"],
                value=["yolov5"],
                label="Select Model(s)",
            )
            run_button = gr.Button("Run", elem_id="run_button")
        # Right: Display sample image
        with gr.Column():
            sample_display = gr.Image(
                value=load_sample_image(default_sample),
                label="Selected Sample Image",
            )
    # Results and Netron
    with gr.Row():
        result_gallery = gr.Gallery(
            label="Results",
            elem_id="gallery",
            rows=1,
            height=500,
        )
        # Display Netron iframe
        netron_display = gr.HTML(view_netron_model())
    # Sample selection update
    sample_selection.change(
        fn=load_sample_image,
        inputs=sample_selection,
        outputs=sample_display,
    )
    # Process image
    run_button.click(
        fn=process_image,
        inputs=[sample_selection, upload_image, selected_models],
        outputs=[result_gallery],
    )
# Launch Gradio app
if __name__ == "__main__":
    interface.launch(share=True)
"""
import gradio as gr | |
import onnxruntime | |
import os | |
def visualize_onnx_model(onnx_model_path):
    """
    Launches Netron on the uploaded ONNX model and reports the outcome.

    Args:
        onnx_model_path: The value passed by Gradio's ``"file"`` input. Depending
            on the Gradio version and component ``type``, this may be a
            filesystem path (``str`` / ``os.PathLike``), a temp-file wrapper
            exposing ``.name``, raw ``bytes``, or a readable file-like object.
            ``None`` means nothing was uploaded.

    Returns:
        A human-readable status message for the ``"text"`` output box.
    """
    import subprocess
    import tempfile

    if onnx_model_path is None:
        return "No model file provided."
    try:
        if isinstance(onnx_model_path, (str, os.PathLike)):
            # Gradio hands over a path to the already-saved upload; use it directly.
            model_path = os.fspath(onnx_model_path)
        elif isinstance(getattr(onnx_model_path, "name", None), str) and os.path.isfile(
            onnx_model_path.name
        ):
            # Temp-file wrapper whose contents already live on disk; reuse that file.
            model_path = onnx_model_path.name
        else:
            if isinstance(onnx_model_path, (bytes, bytearray)):
                data = onnx_model_path
            else:
                data = onnx_model_path.read()
            # Use a uniquely named temp file rather than a fixed name in the CWD,
            # so concurrent uploads cannot overwrite each other.
            with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp:
                tmp.write(data)
            model_path = tmp.name

        if not os.path.isfile(model_path):
            return f"Model file not found: {model_path}"

        # Pass an argument list (no shell), so the path is never shell-interpreted.
        # Use Popen rather than os.system so the handler does not wait for the
        # Netron process to exit.
        subprocess.Popen(["netron", model_path])
    except Exception as e:
        # Top-level UI boundary: log to the console (as before) and also surface
        # the error in the text output instead of returning nothing.
        message = f"Error visualizing model: {e}"
        print(message)
        return message
    return f"Started Netron for {model_path}"
# Create the Gradio interface: a single file upload in, a text box out.
iface = gr.Interface(
    fn=visualize_onnx_model,
    inputs="file",  # Accept the ONNX model as a file upload
    outputs="text",  # Text box showing whatever the handler returns
    title="Netron ONNX Model Visualization",
    description="Upload an ONNX model file to visualize with Netron.",
)

# Launch the Gradio app only when this file is run as a script (e.g.
# `python app.py`). Importing the module, for tests or tooling, then no longer
# starts a server as a side effect.
# NOTE(review): share=True asks Gradio for a public share link; it is
# presumably unnecessary when hosted on Hugging Face Spaces. Confirm.
if __name__ == "__main__":
    iface.launch(share=True)