MuGeminorum committed on
Commit de12ba7
1 parent: 58e5f22

add show copy btn

Files changed (1)
  1. app.py +33 -24
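
In substance, the commit enables a copy button on the output textbox (show_copy_button=True) and reformats app.py with black-style quoting and line wrapping. As an isolated, hedged sketch of the new option only (not the repository's code; it assumes a Gradio release in which gr.Textbox accepts show_copy_button and uses a placeholder function in place of the SVHN model):

import gradio as gr


def recognize(image_path):
    # Placeholder stand-in for the real SVHN inference call.
    return "digits would appear here"


demo = gr.Interface(
    fn=recognize,
    inputs=gr.Image(type="filepath", label="Upload photo"),
    outputs=gr.Textbox(label="Recognition result", show_copy_button=True),
)

demo.launch()
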
app.py CHANGED
@@ -7,10 +7,14 @@ from PIL import Image
 from model import Model
 from torchvision import transforms
 import warnings
+
 warnings.filterwarnings("ignore")
 
 
-def download_model(url="https://huggingface.co/MuGeminorum/SVHN-Recognition/resolve/main/model-122000.pth", local_path="model-122000.pth"):
+def download_model(
+    url="https://huggingface.co/MuGeminorum/SVHN-Recognition/resolve/main/model-122000.pth",
+    local_path="model-122000.pth",
+):
     # Check if the file exists
     if not os.path.exists(local_path):
         print(f"Downloading file from {url}...")
@@ -18,13 +22,13 @@ def download_model(url="https://huggingface.co/MuGeminorum/SVHN-Recognition/reso
         response = requests.get(url, stream=True)
 
         # Get the total file size in bytes
-        total_size = int(response.headers.get('content-length', 0))
+        total_size = int(response.headers.get("content-length", 0))
 
         # Initialize the tqdm progress bar
-        progress_bar = tqdm(total=total_size, unit='B', unit_scale=True)
+        progress_bar = tqdm(total=total_size, unit="B", unit_scale=True)
 
         # Open a local file with write-binary mode
-        with open(local_path, 'wb') as file:
+        with open(local_path, "wb") as file:
             for data in response.iter_content(chunk_size=1024):
                 # Update the progress bar
                 progress_bar.update(len(data))
@@ -42,22 +46,31 @@ def _infer(path_to_checkpoint_file, path_to_input_image):
     model = Model()
     model.restore(path_to_checkpoint_file)
     # model.cuda()
-    outstr = ''
+    outstr = ""
 
     with torch.no_grad():
-        transform = transforms.Compose([
-            transforms.Resize([64, 64]),
-            transforms.CenterCrop([54, 54]),
-            transforms.ToTensor(),
-            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
-        ])
+        transform = transforms.Compose(
+            [
+                transforms.Resize([64, 64]),
+                transforms.CenterCrop([54, 54]),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+            ]
+        )
 
         image = Image.open(path_to_input_image)
-        image = image.convert('RGB')
+        image = image.convert("RGB")
         image = transform(image)
         images = image.unsqueeze(dim=0)  # .cuda()
 
-        length_logits, digit1_logits, digit2_logits, digit3_logits, digit4_logits, digit5_logits = model.eval()(images)
+        (
+            length_logits,
+            digit1_logits,
+            digit2_logits,
+            digit3_logits,
+            digit4_logits,
+            digit5_logits,
+        ) = model.eval()(images)
 
         length_prediction = length_logits.max(1)[1]
         digit1_prediction = digit1_logits.max(1)[1]
@@ -71,7 +84,7 @@ def _infer(path_to_checkpoint_file, path_to_input_image):
             digit2_prediction.item(),
             digit3_prediction.item(),
             digit4_prediction.item(),
-            digit5_prediction.item()
+            digit5_prediction.item(),
         ]
 
         for i in range(length_prediction.item()):
@@ -89,23 +102,19 @@ def inference(image_path, weight_path="model-122000.pth"):
     )
 
     if not image_path:
-        image_path = './examples/03.png'
+        image_path = "./examples/03.png"
 
     return _infer(weight_path, image_path)
 
 
-if __name__ == '__main__':
-    example_images = [
-        './examples/03.png',
-        './examples/457.png',
-        './examples/2003.png'
-    ]
+if __name__ == "__main__":
+    example_images = ["./examples/03.png", "./examples/457.png", "./examples/2003.png"]
 
     iface = gr.Interface(
         fn=inference,
-        inputs=gr.Image(type='filepath', label='Upload photo'),
-        outputs=gr.Textbox(label='Recognition result'),
-        examples=example_images
+        inputs=gr.Image(type="filepath", label="Upload photo"),
+        outputs=gr.Textbox(label="Recognition result", show_copy_button=True),
+        examples=example_images,
     )
 
     iface.launch()
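
For completeness, a hypothetical usage sketch (not part of the commit) that exercises the script after this change, assuming app.py from this commit is importable and the example images ship under ./examples/:

# Hypothetical usage; download_model() fetches model-122000.pth if missing,
# and inference() returns the recognized digit string for the image.
from app import download_model, inference

download_model()
print(inference("./examples/457.png"))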