BhumikaMak commited on
Commit
773cb9c
·
verified ·
1 Parent(s): 173d15f
Files changed (1) hide show
  1. yolov5.py +153 -77
yolov5.py CHANGED
@@ -82,90 +82,166 @@ def xai_yolov5(image):
82
  return Image.fromarray(final_image), caption
83
 
84
 
85
- """
86
- import yaml
87
- import torch
88
- import warnings
89
- warnings.filterwarnings('ignore')
90
- from PIL import Image
91
  import numpy as np
92
- import requests
93
- import cv2
94
  import torch
95
- from pytorch_grad_cam import DeepFeatureFactorization
96
- from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
97
- from pytorch_grad_cam.utils.image import deprocess_image, show_factorization_on_image
 
 
 
 
 
 
 
98
 
99
- # Check if CUDA is available
100
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
101
- mean = [0.485, 0.456, 0.406] # Mean for RGB channels
102
- std = [0.229, 0.224, 0.225] # Standard deviation for RGB channels
103
- # Load YOLOv5 model and move it to the appropriate device
104
- model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True).to(device)
105
- print(f"Loaded YOLOv5 model on {device}")
106
- def create_labels(concept_scores, top_k=2):
107
-
108
- yolov5_categories_url = \
109
- "https://github.com/ultralytics/yolov5/raw/master/data/coco128.yaml" # URL to the YOLOv5 categories file
110
- yaml_data = requests.get(yolov5_categories_url).text
111
- labels = yaml.safe_load(yaml_data)['names'] # Parse the YAML file to get class names
112
-
113
- concept_categories = np.argsort(concept_scores, axis=1)[:, ::-1][:, :top_k]
114
- concept_labels_topk = []
115
- for concept_index in range(concept_categories.shape[0]):
116
- categories = concept_categories[concept_index, :]
117
- concept_labels = []
118
- for category in categories:
119
- score = concept_scores[concept_index, category]
120
- label = f"{labels[category]}:{score:.2f}"
121
- concept_labels.append(label)
122
- concept_labels_topk.append("\n".join(concept_labels))
123
- return concept_labels_topk
124
-
125
- def get_image_from_url(url, device):
126
-
127
-
128
- img = np.array(Image.open(os.path.join(os.getcwd(), "data/xai/sample1.jpeg")))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  img = cv2.resize(img, (640, 640))
130
- rgb_img_float = np.float32(img) /255.0
131
- input_tensor = torch.from_numpy(rgb_img_float).permute(2, 0, 1).unsqueeze(0).to(device)
132
- return img, rgb_img_float, input_tensor
133
 
134
- def visualize_image(model, img_url, n_components=20, top_k=1, lyr_idx = 2):
135
- img, rgb_img_float, input_tensor = get_image_from_url(img_url, device)
136
-
137
- # Specify the target layer for DeepFeatureFactorization (e.g., YOLO's backbone)
138
- target_layer = model.model.model.model[-lyr_idx] # Select a feature extraction layer
139
-
140
- dff = DeepFeatureFactorization(model=model.model, target_layer=target_layer)
141
-
142
- # Run DFF on the input tensor
143
- concepts, batch_explanations = dff(input_tensor, n_components)
144
-
145
- # Softmax normalization
146
- concept_outputs = torch.softmax(torch.from_numpy(concepts), axis=-1).numpy()
147
- concept_label_strings = create_labels(concept_outputs, top_k=top_k)
148
-
149
- # Visualize explanations
150
- visualization = show_factorization_on_image(rgb_img_float,
151
- batch_explanations[0],
152
- image_weight=0.2,
153
- concept_labels=concept_label_strings)
154
 
155
- import matplotlib.pyplot as plt
156
- plt.imshow(visualization)
157
- plt.savefig("test" + str(lyr_idx) + ".png")
158
- result = np.hstack((img, visualization))
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
- # Resize for visualization
162
- if result.shape[0] > 500:
163
- result = cv2.resize(result, (result.shape[1]//4, result.shape[0]//4))
164
-
165
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
- # Test with images
168
- for indx in range(2,12):
169
- Image.fromarray(visualize_image(model,
170
- "https://github.com/jacobgil/pytorch-grad-cam/blob/master/examples/both.png?raw=true", lyr_idx = indx))
171
- """
 
82
  return Image.fromarray(final_image), caption
83
 
84
 
85
+
 
 
 
 
 
86
  import numpy as np
87
+ from PIL import Image
 
88
  import torch
89
+ import cv2
90
+ from typing import Callable, List, Tuple, Optional
91
+ from sklearn.decomposition import NMF
92
+ from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
93
+ from pytorch_grad_cam.utils.image import scale_cam_image, create_labels_legend, show_factorization_on_image
94
+ import matplotlib.pyplot as plt
95
+ from pytorch_grad_cam.utils.image import show_factorization_on_image
96
+ import requests
97
+ import yaml
98
+ import matplotlib.patches as patches
99
 
 
100
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
101
+ def dff_l(activations, model, n_components):
102
+ batch_size, channels, h, w = activations.shape
103
+ print('activation', activations.shape)
104
+ target_layer_index = 4
105
+ reshaped_activations = activations.transpose((1, 0, 2, 3))
106
+ reshaped_activations[np.isnan(reshaped_activations)] = 0
107
+ reshaped_activations = reshaped_activations.reshape(
108
+ reshaped_activations.shape[0], -1)
109
+ offset = reshaped_activations.min(axis=-1)
110
+ reshaped_activations = reshaped_activations - offset[:, None]
111
+ model = NMF(n_components=n_components, init='random', random_state=0)
112
+ W = model.fit_transform(reshaped_activations)
113
+ H = model.components_
114
+ concepts = W + offset[:, None]
115
+ explanations = H.reshape(n_components, batch_size, h, w)
116
+ explanations = explanations.transpose((1, 0, 2, 3))
117
+ return concepts, explanations
118
+
119
+ class DeepFeatureFactorization:
120
+ def __init__(self,
121
+ model: torch.nn.Module,
122
+ target_layer: torch.nn.Module,
123
+ reshape_transform: Callable = None,
124
+ computation_on_concepts=None
125
+ ):
126
+ self.model = model
127
+ self.computation_on_concepts = computation_on_concepts
128
+ self.activations_and_grads = ActivationsAndGradients(
129
+ self.model, [target_layer], reshape_transform)
130
+
131
+ def __call__(self,
132
+ input_tensor: torch.Tensor,
133
+ model: torch.nn.Module,
134
+ n_components: int = 16):
135
+ if isinstance(input_tensor, np.ndarray):
136
+ input_tensor = torch.from_numpy(input_tensor) # Convert NumPy array
137
+
138
+ batch_size, channels, h, w = input_tensor.size()
139
+ _ = self.activations_and_grads(input_tensor)
140
+
141
+ with torch.no_grad():
142
+ activations = self.activations_and_grads.activations[0].cpu(
143
+ ).numpy()
144
+
145
+ concepts, explanations = dff_l(activations, model, n_components=n_components)
146
+ processed_explanations = []
147
+
148
+ for batch in explanations:
149
+ processed_explanations.append(scale_cam_image(batch, (w, h)))
150
+
151
+ if self.computation_on_concepts:
152
+ with torch.no_grad():
153
+ concept_tensors = torch.from_numpy(
154
+ np.float32(concepts).transpose((1, 0)))
155
+ concept_outputs = self.computation_on_concepts(
156
+ concept_tensors).cpu().numpy()
157
+ return concepts, processed_explanations, concept_outputs
158
+ else:
159
+ return concepts, processed_explanations, explanations
160
+
161
+ def __del__(self):
162
+ self.activations_and_grads.release()
163
+
164
+ def __exit__(self, exc_type, exc_value, exc_tb):
165
+ self.activations_and_grads.release()
166
+ if isinstance(exc_value, IndexError):
167
+ # Handle IndexError here...
168
+ print(
169
+ f"An exception occurred in ActivationSummary with block: {exc_type}. Message: {exc_value}")
170
+ return True
171
+
172
+
173
+ def dff_nmf(image, target_lyr, n_components):
174
+ mean = [0.485, 0.456, 0.406] # Mean for RGB channels
175
+ std = [0.229, 0.224, 0.225] # Standard deviation for RGB channels
176
+ img, rgb_img_float, input_tensor = image.to(device)
177
  img = cv2.resize(img, (640, 640))
178
+ rgb_img_float = np.float32(img) / 255.0
179
+ input_tensor = torch.from_numpy(rgb_img_float).permute(2, 0, 1).unsqueeze(0).to(device)
 
180
 
181
+ model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True).to(device)
182
+ dff= DeepFeatureFactorization(model=model,
183
+ target_layer=model.model.model.model[int(target_lyr)],
184
+ computation_on_concepts=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
+ concepts, batch_explanations, explanations = dff(input_tensor, model, n_components)
 
 
 
187
 
188
+
189
+ yolov5_categories_url = \
190
+ "https://github.com/ultralytics/yolov5/raw/master/data/coco128.yaml" # URL to the YOLOv5 categories file
191
+ yaml_data = requests.get(yolov5_categories_url).text
192
+ labels = yaml.safe_load(yaml_data)['names'] # Parse the YAML file to get class names
193
+ num_classes = model.model.model.model[-1].nc
194
+
195
+ for indx in range( explanations[0].shape[0]):
196
+ upsampled_input = explanations[0][indx]
197
+ upsampled_input = torch.tensor(upsampled_input)
198
+ device = next(model.parameters()).device
199
+ input_tensor = upsampled_input.unsqueeze(0)
200
+ input_tensor = input_tensor.unsqueeze(1).repeat(1, 128, 1, 1)
201
+ detection_lyr = model.model.model.model[-1]
202
+ output1 = detection_lyr.m[0](input_tensor.to(device))
203
+ objectness = output1[..., 4] # Objectness score (index 4)
204
+ class_scores = output1[..., 5:] # Class scores (from index 5 onwards, representing 80 classes)
205
+ objectness = torch.sigmoid(objectness)
206
+ class_scores = torch.sigmoid(class_scores)
207
+ confidence_mask = objectness > 0.5
208
+ objectness = objectness[confidence_mask]
209
+ class_scores = class_scores[confidence_mask]
210
+ scores, class_ids = class_scores.max(dim=-1) # Get max class score per cell
211
+ scores = scores * objectness # Adjust scores by objectness
212
+ boxes = output1[..., :4] # First 4 values are x1, y1, x2, y2
213
+ boxes = boxes[confidence_mask] # Filter boxes by confidence mask
214
+ fig, ax = plt.subplots(1, figsize=(10, 10))
215
+ ax.imshow(torch.tensor(batch_explanations[0][indx]).cpu().numpy(), cmap="gray") # Display image
216
+ top_score_idx = scores.argmax(dim=0) # Get the index of the max score
217
+ top_score = scores[top_score_idx].item()
218
+ top_class_id = class_ids[top_score_idx].item()
219
+ top_box = boxes[top_score_idx].cpu().numpy()
220
+ scale_factor = 16
221
+ x1, y1, x2, y2 = top_box
222
+ x1, y1, x2, y2 = x1 * scale_factor, y1 * scale_factor, x2 * scale_factor, y2 * scale_factor
223
+ rect = patches.Rectangle(
224
+ (x1, y1), x2 - x1, y2 - y1,
225
+ linewidth=2, edgecolor='r', facecolor='none')
226
+ ax.add_patch(rect)
227
 
228
+ predicted_label = labels[top_class_id] # Map ID to label
229
+ ax.text(x1, y1, f"{predicted_label}: {top_score:.2f}",
230
+ color='r', fontsize=12, verticalalignment='top')
231
+ plt.show()
232
+ plt.savefig("test_" + str(indx) + ".png" )
233
+ plt.clf()
234
+ return rgb_img_float, explanations
235
+
236
+
237
+ def visualize_batch_explanations(rgb_img_float, batch_explanations, image_weight=0.7):
238
+ for i, explanation in enumerate(batch_explanations):
239
+ # Create visualization for each explanation
240
+ visualization = show_factorization_on_image(rgb_img_float, explanation, image_weight=image_weight)
241
+ plt.figure()
242
+ plt.imshow(visualization) # Correctly pass the visualization data
243
+ plt.title(f'Explanation {i + 1}') # Set the title for each plot
244
+ plt.axis('off') # Hide axes
245
+ plt.show() # Show the plot
246
+ plt.savefig("test_w.png")
247