BhumikaMak committed
Commit dbd2a18 · 1 Parent(s): b16b91f

Add: model summary support

Files changed (1)
  1. app.py +85 -44
app.py CHANGED
@@ -1,42 +1,40 @@
-import numpy as np
-import cv2
+import gradio as gr
+import netron
 import os
+import threading
+import time
 from PIL import Image
-import torchvision.transforms as transforms
-import gradio as gr
-from yolov5 import xai_yolov5
-from yolov8 import xai_yolov8s
+import cv2
+import numpy as np
+import torch
 
+# Sample images directory
 sample_images = {
     "Sample 1": os.path.join(os.getcwd(), "data/xai/sample1.jpeg"),
-    "Sample 2": os.path.join(os.getcwd(), "data/xai/sample2.jpg"),
+    "Sample 2": os.path.join(os.getcwd(), "data/xai/sample2.jpg"),
 }
+
+# Preloaded model file path (update this path as needed)
+preloaded_model_file = os.path.join(os.getcwd(), "weight_files/yolov5.onnx")  # Example path
+
 def load_sample_image(sample_name):
+    """Load a sample image based on user selection."""
     image_path = sample_images.get(sample_name)
     if image_path and os.path.exists(image_path):
         return Image.open(image_path)
     return None
 
-default_sample_image = load_sample_image("Sample 1")
-
-def load_sample_image(choice):
-    if choice in sample_images:
-        image_path = sample_images[choice]
-        return cv2.imread(image_path)[:, :, ::-1]
-    else:
-        raise ValueError("Invalid sample selection.")
-
-
-def process_image(sample_choice, uploaded_image, yolo_versions=["yolov5"]):
-    print(sample_choice, upload_image)
+def process_image(sample_choice, uploaded_image, yolo_versions):
+    """Process the image using selected YOLO models."""
     if uploaded_image is not None:
         image = uploaded_image  # Use the uploaded image
     else:
-        # Otherwise, use the selected sample image
-        image = load_sample_image(sample_choice)
+        image = load_sample_image(sample_choice)  # Use selected sample image
+
     image = np.array(image)
     image = cv2.resize(image, (640, 640))
     result_images = []
+
     for yolo_version in yolo_versions:
         if yolo_version == "yolov5":
             result_images.append(xai_yolov5(image))
@@ -44,50 +42,93 @@ def process_image(sample_choice, uploaded_image, yolo_versions=["yolov5"]):
             result_images.append(xai_yolov8s(image))
         else:
             result_images.append((Image.fromarray(image), f"{yolo_version} not yet implemented."))
+
     return result_images
 
-with gr.Blocks() as interface:
+def serve_netron(model_file):
+    """Start the Netron server in a separate thread."""
+    threading.Thread(target=netron.start, args=(model_file,), daemon=True).start()
+    time.sleep(1)  # Give some time for the server to start
+    return "http://localhost:8080"  # Default Netron URL
+def view_model():
+    """Handle model visualization using preloaded model file."""
+    if not os.path.exists(preloaded_model_file):
+        return "Model file not found."
+
+    netron_url = serve_netron(preloaded_model_file)
+    return f'<iframe src="{netron_url}" width="100%" height="600px"></iframe>'
+
+# Custom CSS for styling (optional)
+custom_css = """
+#run_button {
+    background-color: purple;
+    color: white;
+    width: 120px;
+    border-radius: 5px;
+    font-size: 14px;
+}
+"""
+
+with gr.Blocks(css=custom_css) as interface:
     gr.Markdown("# XAI: Visualize Object Detection of Your Models")
-    gr.Markdown("Select a sample image to visualize object detection.")
+
     default_sample = "Sample 1"
-    with gr.Row(elem_classes="orchid-green-bg"):
+
+    with gr.Row():
         # Left side: Sample selection and upload image
         with gr.Column():
             sample_selection = gr.Radio(
                 choices=list(sample_images.keys()),
                 label="Select a Sample Image",
                 type="value",
-                value=default_sample,  # Set default selection
+                value=default_sample,
             )
-            # Upload image below sample selection
-            gr.Markdown("**Or upload your own image:**")
+
             upload_image = gr.Image(
                 label="Upload an Image",
-                type="pil",  # Correct type for file path compatibility
+                type="pil",
             )
-        # Right side: Selected sample image display
-        sample_display = gr.Image(
-            value=load_sample_image(default_sample),
-            label="Selected Sample Image",
+
+            selected_models = gr.CheckboxGroup(
+                choices=["yolov5", "yolov8s"],
+                value=["yolov5"],
+                label="Select Model(s)",
+            )
+
+            run_button = gr.Button("Run", elem_id="run_button")
+
+        with gr.Column():
+            sample_display = gr.Image(
+                value=load_sample_image(default_sample),
+                label="Selected Sample Image",
+            )
+
+    # Below the sample image, display results and architecture side by side
+    with gr.Row():
+        result_gallery = gr.Gallery(
+            label="Results",
+            elem_id="gallery",
+            rows=1,
+            height=500,
        )
-
+
+    netron_display = gr.HTML(label="Netron Visualization")
+
     sample_selection.change(
         fn=load_sample_image,
         inputs=sample_selection,
         outputs=sample_display,
     )
 
-    selected_models = gr.CheckboxGroup(
-        choices=["yolov5", "yolov8s"],
-        value=["yolov5"],
-        label="Select Model(s)",
-    )
-    result_gallery = gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500)
-
-    gr.Button("Run").click(
+    run_button.click(
         fn=process_image,
-        inputs=[sample_selection, upload_image, selected_models],  # Include both options
-        outputs=result_gallery,
+        inputs=[sample_selection, upload_image, selected_models],
+        outputs=[result_gallery],
     )
 
-interface.launch(share=True)
+    # Update Netron display when the interface loads
+    netron_display.value = view_model()  # Directly set the value
+
+# Launching Gradio app and handling Netron visualization separately.
+if __name__ == "__main__":
+    interface.launch(share=True)
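
For quick local verification, a minimal standalone sketch of the Netron serve-and-embed pattern that the new serve_netron()/view_model() helpers rely on. This is not part of the commit: it assumes the netron package is installed and serves on localhost:8080 by default, netron_iframe is a hypothetical helper name, and the ONNX path is the commit's placeholder and may not exist on your machine.

# Standalone sketch (assumption-laden, not the committed code): start Netron in a
# background thread and build the iframe HTML that the app feeds into gr.HTML.
import os
import threading
import time

import netron

model_file = os.path.join(os.getcwd(), "weight_files/yolov5.onnx")  # placeholder path from the commit

def netron_iframe(path: str) -> str:
    """Start Netron in a daemon thread and return an iframe pointing at the viewer."""
    if not os.path.exists(path):
        return "Model file not found."
    threading.Thread(target=netron.start, args=(path,), daemon=True).start()
    time.sleep(1)  # give the server a moment to come up
    return '<iframe src="http://localhost:8080" width="100%" height="600px"></iframe>'

if __name__ == "__main__":
    print(netron_iframe(model_file))
    time.sleep(30)  # keep the daemon thread alive so the viewer stays reachable

Because the thread is a daemon, the Netron server stops as soon as the main process exits; inside the Gradio app the long-running interface.launch() keeps it alive instead of the sleep used here.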