z-uo commited on
Commit
b6a4ee3
β€’
1 Parent(s): e708547

add possibility to use two networks

Browse files
app.py CHANGED
@@ -13,8 +13,8 @@ from models.preprocessing import *
13
  from models.misc import nested_tensor_from_tensor_list
14
 
15
 
16
- model = create_letr()
17
-
18
  # PREPARE PREPROCESSING
19
  # transform_test = transforms.Compose([
20
  # transforms.Resize((test_size)),
@@ -38,7 +38,7 @@ normalize_1100 = Compose([
38
  ])
39
 
40
 
41
- def predict(inp, size):
42
  image = Image.fromarray(inp.astype('uint8'), 'RGB')
43
  h, w = image.height, image.width
44
  orig_size = torch.as_tensor([int(h), int(w)])
@@ -52,7 +52,10 @@ def predict(inp, size):
52
  inputs = nested_tensor_from_tensor_list([img])
53
 
54
  with torch.no_grad():
55
- outputs = model(inputs)[0]
 
 
 
56
 
57
  draw_fig(image, outputs, orig_size)
58
 
@@ -62,6 +65,7 @@ def predict(inp, size):
62
  inputs = [
63
  gr.inputs.Image(),
64
  gr.inputs.Radio(["256", "512", "1100"]),
 
65
  ]
66
  outputs = gr.outputs.Image()
67
  gr.Interface(
@@ -69,8 +73,8 @@ gr.Interface(
69
  inputs=inputs,
70
  outputs=outputs,
71
  examples=[
72
- ["demo.png", '256'],
73
- ["tappeto-per-calibrazione.jpg", '256']
74
  ],
75
  title="LETR",
76
  description="Model for line detection..."
 
13
  from models.misc import nested_tensor_from_tensor_list
14
 
15
 
16
+ model = create_letr('resnet50/checkpoint0024.pth')
17
+ model101 = create_letr('resnet101/checkpoint0024.pth')
18
  # PREPARE PREPROCESSING
19
  # transform_test = transforms.Compose([
20
  # transforms.Resize((test_size)),
 
38
  ])
39
 
40
 
41
+ def predict(inp, size, model_name):
42
  image = Image.fromarray(inp.astype('uint8'), 'RGB')
43
  h, w = image.height, image.width
44
  orig_size = torch.as_tensor([int(h), int(w)])
 
52
  inputs = nested_tensor_from_tensor_list([img])
53
 
54
  with torch.no_grad():
55
+ if model_name == 'resnet101':
56
+ outputs = model101(inputs)[0]
57
+ else:
58
+ outputs = model(inputs)[0]
59
 
60
  draw_fig(image, outputs, orig_size)
61
 
 
65
  inputs = [
66
  gr.inputs.Image(),
67
  gr.inputs.Radio(["256", "512", "1100"]),
68
+ gr.inputs.Radio(["resnet50", "resnet101"]),
69
  ]
70
  outputs = gr.outputs.Image()
71
  gr.Interface(
 
73
  inputs=inputs,
74
  outputs=outputs,
75
  examples=[
76
+ ["demo.png", '256', "resnet50"],
77
+ ["tappeto-per-calibrazione.jpg", '256', "resnet50"]
78
  ],
79
  title="LETR",
80
  description="Model for line detection..."
resnet101/checkpoint0024.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab07a2ddf5e088540941c755e2cccc66e019dcf94b3ef488bab25f5a76490bb9
3
+ size 457215616
checkpoint0024.pth β†’ resnet50/checkpoint0024.pth RENAMED
File without changes
test.py CHANGED
@@ -7,9 +7,9 @@ from models.letr import build
7
  from models.misc import nested_tensor_from_tensor_list
8
  from models.preprocessing import Compose, ToTensor, Resize, Normalize
9
 
10
- def create_letr():
11
  # obtain checkpoints
12
- checkpoint = torch.load('checkpoint0024.pth', map_location='cpu')
13
 
14
  # load model
15
  args = checkpoint['args']
@@ -44,7 +44,7 @@ def draw_fig(image, outputs, orig_size):
44
  draw.line((x1, y1, x2, y2), fill=500)
45
 
46
  if __name__ == '__main__':
47
- model = create_letr()
48
 
49
  test_size = 256
50
  normalize = Compose([
 
7
  from models.misc import nested_tensor_from_tensor_list
8
  from models.preprocessing import Compose, ToTensor, Resize, Normalize
9
 
10
+ def create_letr(path):
11
  # obtain checkpoints
12
+ checkpoint = torch.load(path, map_location='cpu')
13
 
14
  # load model
15
  args = checkpoint['args']
 
44
  draw.line((x1, y1, x2, y2), fill=500)
45
 
46
  if __name__ == '__main__':
47
+ model = create_letr('resnet50/checkpoint0024.pth')
48
 
49
  test_size = 256
50
  normalize = Compose([