Matteo Sirri commited on
Commit
91ccbe1
1 Parent(s): b4616bc

feat: add path baseline

Browse files
Files changed (2) hide show
  1. app.py +10 -10
  2. model_baseline_coco_FT_MOT17.pth +3 -0
app.py CHANGED
@@ -12,20 +12,20 @@ import torchvision.transforms as T
12
 
13
  logging.getLogger('PIL').setLevel(logging.CRITICAL)
14
 
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
 
18
  def load_model(baseline: bool = False):
19
  if baseline:
20
- model = fasterrcnn_resnet50_fpn(
21
- weights="DEFAULT")
22
  else:
23
- model = fasterrcnn_resnet50_fpn()
24
- in_features = model.roi_heads.box_predictor.cls_score.in_features
25
- model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
26
- checkpoint = torch.load(
27
- osp.join(os.getcwd(), "model_split3_FT_MOT17.pth"), map_location="cpu")
28
- model.load_state_dict(checkpoint["model"])
 
29
  model.to(device)
30
  model.eval()
31
  return model
@@ -56,7 +56,7 @@ def frcnn_coco(image):
56
  title = "Domain shift adaption on pedestrian detection with Faster R-CNN"
57
  description = '<p style="text-align:center">School in AI: Deep Learning, Vision and Language for Industry - second edition final project work by Matteo Sirri.</p> '
58
  examples = ["001.jpg", "002.jpg", "003.jpg",
59
- "004.jpg", "005.jpg", "006.jpg", "007.jpg", ]
60
 
61
  io_baseline = gr.Interface(frcnn_coco, gr.Image(type="pil"), gr.Image(
62
  type="file", label="Baseline Model trained on COCO + FT on MOT17"))
 
12
 
13
  logging.getLogger('PIL').setLevel(logging.CRITICAL)
14
 
15
+ device = torch.device("cpu")
16
 
17
 
18
  def load_model(baseline: bool = False):
19
  if baseline:
20
+ path = osp.join(os.getcwd(), "model_baseline_coco_FT_MOT17.pth")
 
21
  else:
22
+ path = osp.join(os.getcwd(), "model_split3_FT_MOT17.pth")
23
+
24
+ model = fasterrcnn_resnet50_fpn()
25
+ in_features = model.roi_heads.box_predictor.cls_score.in_features
26
+ model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
27
+ checkpoint = torch.load(path, map_location="cpu")
28
+ model.load_state_dict(checkpoint["model"])
29
  model.to(device)
30
  model.eval()
31
  return model
 
56
  title = "Domain shift adaption on pedestrian detection with Faster R-CNN"
57
  description = '<p style="text-align:center">School in AI: Deep Learning, Vision and Language for Industry - second edition final project work by Matteo Sirri.</p> '
58
  examples = ["001.jpg", "002.jpg", "003.jpg",
59
+ "004.jpg", "005.jpg", "006.jpg", "007.jpg"]
60
 
61
  io_baseline = gr.Interface(frcnn_coco, gr.Image(type="pil"), gr.Image(
62
  type="file", label="Baseline Model trained on COCO + FT on MOT17"))
model_baseline_coco_FT_MOT17.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7647a5e485632e739efd6444b743d191c5f32916a05988c4a6646c40d3cce47
3
+ size 330056483