Matteo Sirri committed on
Commit
b4616bc
·
1 Parent(s): 0888435

fix: fix typo

Browse files
app.py CHANGED
@@ -62,7 +62,7 @@ io_baseline = gr.Interface(frcnn_coco, gr.Image(type="pil"), gr.Image(
62
  type="file", label="Baseline Model trained on COCO + FT on MOT17"))
63
 
64
  io_custom = gr.Interface(frcnn_motsynth, gr.Image(type="pil"), gr.Image(
65
- type="pil", label="Faster R-CNN trained on MOTSynth + FT on MOT17"))
66
 
67
  gr.Parallel(io_baseline, io_custom, title=title,
68
  description=description, examples=examples).launch(enable_queue=True)
 
62
  type="file", label="Baseline Model trained on COCO + FT on MOT17"))
63
 
64
  io_custom = gr.Interface(frcnn_motsynth, gr.Image(type="pil"), gr.Image(
65
+ type="file", label="Faster R-CNN trained on MOTSynth + FT on MOT17"))
66
 
67
  gr.Parallel(io_baseline, io_custom, title=title,
68
  description=description, examples=examples).launch(enable_queue=True)
app.py.7d07da4a5b8438bc3eb3c4039d0839bb.tmp DELETED
@@ -1,68 +0,0 @@
1
- import os.path as osp
2
- import os
3
- import gradio as gr
4
- import torch
5
- import logging
6
- import torchvision
7
- from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn
8
- from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
9
- from src.detection.graph_utils import add_bbox
10
- from src.detection.vision import presets
11
- import torchvision.transforms as T
12
-
13
- logging.getLogger('PIL').setLevel(logging.CRITICAL)
14
-
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
-
17
-
18
- def load_model(baseline: bool = False):
19
- if baseline:
20
- model = fasterrcnn_resnet50_fpn(
21
- weights="DEFAULT")
22
- else:
23
- model = fasterrcnn_resnet50_fpn()
24
- in_features = model.roi_heads.box_predictor.cls_score.in_features
25
- model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
26
- checkpoint = torch.load(
27
- osp.join(os.getcwd(), "model_split3_FT_MOT17.pth"), map_location="cpu")
28
- model.load_state_dict(checkpoint["model"])
29
- model.to(device)
30
- model.eval()
31
- return model
32
-
33
-
34
- def frcnn_motsynth(image):
35
- model = load_model()
36
- transformEval = presets.DetectionPresetEval()
37
- image_tensor = transformEval(image, None)[0]
38
- image_tensor = image_tensor.to(device)
39
- prediction = model([image_tensor])[0]
40
- image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
41
- torchvision.io.write_png(image_w_bbox, "custom_out.png")
42
- return "custom_out.png"
43
-
44
-
45
- def frcnn_coco(image):
46
- model = load_model(baseline=True)
47
- transformEval = presets.DetectionPresetEval()
48
- image_tensor = transformEval(image, None)[0]
49
- image_tensor = image_tensor.to(device)
50
- prediction = model([image_tensor])[0]
51
- image_w_bbox = add_bbox(image_tensor, prediction, 0.80)
52
- torchvision.io.write_png(image_w_bbox, "baseline_out.png")
53
- return "baseline_out.png"
54
-
55
-
56
- title = "Domain shift adaption on pedestrian detection with Faster R-CNN"
57
- description = '<p style="text-align:center">School in AI: Deep Learning, Vision and Language for Industry - second edition final project work by Matteo Sirri.</p> '
58
- examples = ["001.jpg", "002.jpg", "003.jpg",
59
- "004.jpg", "005.jpg", "006.jpg", "007.jpg", ]
60
-
61
- io_baseline = gr.Interface(frcnn_coco, gr.Image(type="pil"), gr.Image(
62
- type="file", label="Baseline Model trained on COCO + FT on MOT17"))
63
-
64
- io_custom = gr.Interface(frcnn_motsynth, gr.Image(type="pil"), gr.Image(
65
- type="file", label="Faster R-CNN trained on MOTSynth + FT on MOT17"))
66
-
67
- gr.Parallel(io_baseline, io_custom, title=title,
68
- description=description, examples=examples).launch(enable_queue=True)