hank1996 committed
Commit 8a3c262 · Parent(s): 8370e1e

Update app.py

Files changed (1):
  1. app.py (+30 -37)
app.py CHANGED
@@ -84,25 +84,12 @@ def detect(img,model):
     print(weights)
     stride =32
     model = torch.jit.load(weights,map_location=device)
-    print(model)
-    imgsz = check_img_size(imgsz, s=stride)
-    #model = model.to(device)
-    #print(111111111)
-
+    model.eval()

     # Set Dataloader
     vid_path, vid_writer = None, None
-    #if webcam:
-        #view_img = check_imshow()
-        #cudnn.benchmark = True  # set True to speed up constant image size inference
-        #dataset = LoadStreams(source, img_size=imgsz, stride=stride)
-    #else:
     dataset = LoadImages(source, img_size=imgsz, stride=stride)

-    # Get names and colors
-    names = model.module.names if hasattr(model, 'module') else model.names
-    colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
-
     # Run inference
     if device.type != 'cpu':
         model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
@@ -111,23 +98,32 @@ def detect(img,model):
         img = torch.from_numpy(img).to(device)
         img = img.half() if half else img.float()  # uint8 to fp16/32
         img /= 255.0  # 0 - 255 to 0.0 - 1.0
+
         if img.ndimension() == 3:
             img = img.unsqueeze(0)

         # Inference
         t1 = time_synchronized()
-        pred = model(img, augment=opt.augment)[0]
+        [pred, anchor_grid], seg, ll = model(img)
+        t2 = time_synchronized()
+
+        # waste time: the incompatibility of torch.jit.trace causes extra time consumption in the demo version,
+        # but this problem will not appear in the official version
+        tw1 = time_synchronized()
+        pred = split_for_trace_model(pred, anchor_grid)
+        tw2 = time_synchronized()

         # Apply NMS
+        t3 = time_synchronized()
         pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
-        t2 = time_synchronized()
+        t4 = time_synchronized()
+
+        da_seg_mask = driving_area_mask(seg)
+        ll_seg_mask = lane_line_mask(ll)

-
         # Process detections
         for i, det in enumerate(pred):  # detections per image
-            #if webcam:  # batch_size >= 1
-                #p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
-            #else:
+
             p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)

             p = Path(p)  # to Path
@@ -142,7 +138,7 @@ def detect(img,model):
                 # Print results
                 for c in det[:, -1].unique():
                     n = (det[:, -1] == c).sum()  # detections per class
-                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string
+                    #s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                 # Write results
                 for *xyxy, conf, cls in reversed(det):
@@ -152,22 +148,18 @@ def detect(img,model):
                         with open(txt_path + '.txt', 'a') as f:
                             f.write(('%g ' * len(line)).rstrip() % line + '\n')

-                    if save_img or view_img:  # Add bbox to image
-                        label = f'{names[int(cls)]} {conf:.2f}'
-                        plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
+                    if save_img:  # Add bbox to image
+                        plot_one_box(xyxy, im0, line_thickness=3)

-            # Print time (inference + NMS)
-            #print(f'{s}Done. ({t2 - t1:.3f}s)')
-
-            # Stream results
-            if view_img:
-                cv2.imshow(str(p), im0)
-                cv2.waitKey(1)  # 1 millisecond
+            # Print time (inference)
+            print(f'{s}Done. ({t2 - t1:.3f}s)')
+            show_seg_result(im0, (da_seg_mask, ll_seg_mask), is_demo=True)

             # Save results (image with detections)
             if save_img:
                 if dataset.mode == 'image':
                     cv2.imwrite(save_path, im0)
+                    print(f"The image with the result is saved in: {save_path}")
                 else:  # 'video' or 'stream'
                     if vid_path != save_path:  # new video
                         vid_path = save_path
@@ -175,18 +167,19 @@ def detect(img,model):
                         vid_writer.release()  # release previous video writer
                     if vid_cap:  # video
                         fps = vid_cap.get(cv2.CAP_PROP_FPS)
-                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                        #w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                        #h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                        w, h = im0.shape[1], im0.shape[0]
                     else:  # stream
                         fps, w, h = 30, im0.shape[1], im0.shape[0]
                         save_path += '.mp4'
                     vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                 vid_writer.write(im0)

-    if save_txt or save_img:
-        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
-        #print(f"Results saved to {save_dir}{s}")
-
+    inf_time.update(t2 - t1, img.size(0))
+    nms_time.update(t4 - t3, img.size(0))
+    waste_time.update(tw2 - tw1, img.size(0))
+    print('inf : (%.4fs/frame)   nms : (%.4fs/frame)' % (inf_time.avg, nms_time.avg))
     print(f'Done. ({time.time() - t0:.3f}s)')

     return Image.fromarray(im0[:,:,::-1])
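The substance of the first hunk: the TorchScript module is switched into inference mode right after loading, and the leftover yolov7-style setup (image-size check, class names, per-class colors, webcam branch) is dropped. A minimal sketch of the resulting load-and-warm-up sequence, with 'yolopv2.pt' as a placeholder path rather than the file this Space necessarily ships:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.jit.load('yolopv2.pt', map_location=device)  # placeholder checkpoint path
model.eval()  # inference mode: no dropout, batch-norm uses running stats

# One throwaway forward pass so CUDA kernels are initialized before any timing.
if device.type != 'cpu':
    imgsz = 640
    model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))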
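The forward pass itself changes shape in the second hunk: instead of pred = model(img, augment=opt.augment)[0], the traced YOLOPv2 model returns the detection head together with two segmentation heads, and because torch.jit.trace flattens the detection head, the boxes have to be re-assembled from the raw output and the anchor grid before NMS (the 'waste time' block the new comments apologize for). A sketch of that consumption path; the import path follows the upstream YOLOPv2 repo layout and is an assumption here:

import torch
from utils.utils import (split_for_trace_model, driving_area_mask,
                         lane_line_mask, non_max_suppression)

model = torch.jit.load('yolopv2.pt')  # placeholder path
model.eval()

img = torch.zeros(1, 3, 640, 640)  # normalized NCHW batch, values in [0, 1]
with torch.no_grad():
    # Detection head (raw predictions + anchor grid) plus two segmentation
    # heads: drivable-area logits and lane-line logits.
    [pred, anchor_grid], seg, ll = model(img)

# Re-assemble the boxes that tracing split apart, then suppress overlaps.
pred = split_for_trace_model(pred, anchor_grid)
pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)

# Binary masks for the drivable area and the lane lines.
da_seg_mask = driving_area_mask(seg)
ll_seg_mask = lane_line_mask(ll)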
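inf_time, nms_time, and waste_time never appear in the context lines, so they must be created elsewhere in app.py. Their update(val, n) / .avg usage matches the running-average meter that YOLOP-style code bases carry; a hypothetical stand-in with the same interface:

class AverageMeter:
    # Tracks a running sum and average of a scalar, e.g. seconds per frame.
    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val: per-batch measurement; n: number of samples it covers.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

inf_time = AverageMeter()
inf_time.update(0.031, 1)  # e.g. t2 - t1 for a one-image batch
print('inf : (%.4fs/frame)' % inf_time.avg)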
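Two smaller fixes in the last hunks are easy to miss. The video writer now derives its frame size from the annotated frame itself rather than from the cv2.CAP_PROP_FRAME_WIDTH/HEIGHT queries, so the (w, h) passed to cv2.VideoWriter always matches the frames actually written; and the return line reverses the channel axis because OpenCV frames are BGR while PIL expects RGB. A self-contained illustration:

import numpy as np
from PIL import Image

im0 = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for an annotated BGR frame

# OpenCV arrays are indexed (height, width, channels), so width is shape[1]
# and height is shape[0] -- the order cv2.VideoWriter expects as (w, h).
w, h = im0.shape[1], im0.shape[0]
assert (w, h) == (640, 480)

# cv2 uses BGR channel order; PIL expects RGB, hence the ::-1 reversal.
out = Image.fromarray(im0[:, :, ::-1])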