Spaces:
Running
Running
add possibility to use two networks
Browse files- app.py +10 -6
- resnet101/checkpoint0024.pth +3 -0
- checkpoint0024.pth β resnet50/checkpoint0024.pth +0 -0
- test.py +3 -3
app.py
CHANGED
@@ -13,8 +13,8 @@ from models.preprocessing import *
|
|
13 |
from models.misc import nested_tensor_from_tensor_list
|
14 |
|
15 |
|
16 |
-
model = create_letr()
|
17 |
-
|
18 |
# PREPARE PREPROCESSING
|
19 |
# transform_test = transforms.Compose([
|
20 |
# transforms.Resize((test_size)),
|
@@ -38,7 +38,7 @@ normalize_1100 = Compose([
|
|
38 |
])
|
39 |
|
40 |
|
41 |
-
def predict(inp, size):
|
42 |
image = Image.fromarray(inp.astype('uint8'), 'RGB')
|
43 |
h, w = image.height, image.width
|
44 |
orig_size = torch.as_tensor([int(h), int(w)])
|
@@ -52,7 +52,10 @@ def predict(inp, size):
|
|
52 |
inputs = nested_tensor_from_tensor_list([img])
|
53 |
|
54 |
with torch.no_grad():
|
55 |
-
|
|
|
|
|
|
|
56 |
|
57 |
draw_fig(image, outputs, orig_size)
|
58 |
|
@@ -62,6 +65,7 @@ def predict(inp, size):
|
|
62 |
inputs = [
|
63 |
gr.inputs.Image(),
|
64 |
gr.inputs.Radio(["256", "512", "1100"]),
|
|
|
65 |
]
|
66 |
outputs = gr.outputs.Image()
|
67 |
gr.Interface(
|
@@ -69,8 +73,8 @@ gr.Interface(
|
|
69 |
inputs=inputs,
|
70 |
outputs=outputs,
|
71 |
examples=[
|
72 |
-
["demo.png", '256'],
|
73 |
-
["tappeto-per-calibrazione.jpg", '256']
|
74 |
],
|
75 |
title="LETR",
|
76 |
description="Model for line detection..."
|
|
|
13 |
from models.misc import nested_tensor_from_tensor_list
|
14 |
|
15 |
|
16 |
+
model = create_letr('resnet50/checkpoint0024.pth')
|
17 |
+
model101 = create_letr('resnet101/checkpoint0024.pth')
|
18 |
# PREPARE PREPROCESSING
|
19 |
# transform_test = transforms.Compose([
|
20 |
# transforms.Resize((test_size)),
|
|
|
38 |
])
|
39 |
|
40 |
|
41 |
+
def predict(inp, size, model_name):
|
42 |
image = Image.fromarray(inp.astype('uint8'), 'RGB')
|
43 |
h, w = image.height, image.width
|
44 |
orig_size = torch.as_tensor([int(h), int(w)])
|
|
|
52 |
inputs = nested_tensor_from_tensor_list([img])
|
53 |
|
54 |
with torch.no_grad():
|
55 |
+
if model_name == 'resnet101':
|
56 |
+
outputs = model101(inputs)[0]
|
57 |
+
else:
|
58 |
+
outputs = model(inputs)[0]
|
59 |
|
60 |
draw_fig(image, outputs, orig_size)
|
61 |
|
|
|
65 |
inputs = [
|
66 |
gr.inputs.Image(),
|
67 |
gr.inputs.Radio(["256", "512", "1100"]),
|
68 |
+
gr.inputs.Radio(["resnet50", "resnet101"]),
|
69 |
]
|
70 |
outputs = gr.outputs.Image()
|
71 |
gr.Interface(
|
|
|
73 |
inputs=inputs,
|
74 |
outputs=outputs,
|
75 |
examples=[
|
76 |
+
["demo.png", '256', "resnet50"],
|
77 |
+
["tappeto-per-calibrazione.jpg", '256', "resnet50"]
|
78 |
],
|
79 |
title="LETR",
|
80 |
description="Model for line detection..."
|
resnet101/checkpoint0024.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ab07a2ddf5e088540941c755e2cccc66e019dcf94b3ef488bab25f5a76490bb9
|
3 |
+
size 457215616
|
checkpoint0024.pth β resnet50/checkpoint0024.pth
RENAMED
File without changes
|
test.py
CHANGED
@@ -7,9 +7,9 @@ from models.letr import build
|
|
7 |
from models.misc import nested_tensor_from_tensor_list
|
8 |
from models.preprocessing import Compose, ToTensor, Resize, Normalize
|
9 |
|
10 |
-
def create_letr():
|
11 |
# obtain checkpoints
|
12 |
-
checkpoint = torch.load(
|
13 |
|
14 |
# load model
|
15 |
args = checkpoint['args']
|
@@ -44,7 +44,7 @@ def draw_fig(image, outputs, orig_size):
|
|
44 |
draw.line((x1, y1, x2, y2), fill=500)
|
45 |
|
46 |
if __name__ == '__main__':
|
47 |
-
model = create_letr()
|
48 |
|
49 |
test_size = 256
|
50 |
normalize = Compose([
|
|
|
7 |
from models.misc import nested_tensor_from_tensor_list
|
8 |
from models.preprocessing import Compose, ToTensor, Resize, Normalize
|
9 |
|
10 |
+
def create_letr(path):
|
11 |
# obtain checkpoints
|
12 |
+
checkpoint = torch.load(path, map_location='cpu')
|
13 |
|
14 |
# load model
|
15 |
args = checkpoint['args']
|
|
|
44 |
draw.line((x1, y1, x2, y2), fill=500)
|
45 |
|
46 |
if __name__ == '__main__':
|
47 |
+
model = create_letr('resnet50/checkpoint0024.pth')
|
48 |
|
49 |
test_size = 256
|
50 |
normalize = Compose([
|