MuGeminorum commited on
Commit
e1e7fa2
1 Parent(s): b57689f

upl base codes

Browse files
Files changed (7) hide show
  1. .gitattributes +10 -11
  2. .gitignore +3 -0
  3. 457.png +0 -0
  4. README.md +23 -11
  5. app.py +99 -0
  6. model.py +178 -0
  7. requirements.txt +4 -0
.gitattributes CHANGED
@@ -1,35 +1,34 @@
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
  *.onnx filter=lfs diff=lfs merge=lfs -text
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
  *.bz2 filter=lfs diff=lfs merge=lfs -text
 
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
 
11
  *.model filter=lfs diff=lfs merge=lfs -text
12
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
13
  *.onnx filter=lfs diff=lfs merge=lfs -text
14
  *.ot filter=lfs diff=lfs merge=lfs -text
15
  *.parquet filter=lfs diff=lfs merge=lfs -text
16
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
17
  *.pt filter=lfs diff=lfs merge=lfs -text
18
  *.pth filter=lfs diff=lfs merge=lfs -text
19
  *.rar filter=lfs diff=lfs merge=lfs -text
 
20
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
22
  *.tflite filter=lfs diff=lfs merge=lfs -text
23
  *.tgz filter=lfs diff=lfs merge=lfs -text
 
24
  *.xz filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ *.db* filter=lfs diff=lfs merge=lfs -text
29
+ *.ark* filter=lfs diff=lfs merge=lfs -text
30
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
31
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
32
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
33
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
34
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__/*
2
+ *.pth
3
+ flagged/*
457.png ADDED
README.md CHANGED
@@ -1,13 +1,25 @@
1
  ---
2
- title: SVHN Recognition
3
- emoji:
4
- colorFrom: yellow
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 4.12.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
 
 
 
 
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
1
  ---
2
+ # 详细文档见https://modelscope.cn/docs/%E5%88%9B%E7%A9%BA%E9%97%B4%E5%8D%A1%E7%89%87
3
+ domain: #领域:cv/nlp/audio/multi-modal/AutoML
4
+ # - cv
5
+ tags: #自定义标签
6
+ -
7
+ datasets: #关联数据集
8
+ evaluation:
9
+ #- damotest/beans
10
+ test:
11
+ #- damotest/squad
12
+ train:
13
+ #- modelscope/coco_2014_caption
14
+ models: #关联模型
15
+ #- damo/speech_charctc_kws_phone-xiaoyunxiaoyun
16
 
17
+ ## 启动文件(若SDK为Gradio/Streamlit,默认为app.py, 若为Static HTML, 默认为index.html)
18
+ # deployspec:
19
+ # entry_file: app.py
20
+ license: MIT License
21
+ ---
22
+ #### Clone with HTTP
23
+ ```bash
24
+ git clone https://www.modelscope.cn/studios/MuGeminorum/SVHN-Recognition.git
25
+ ```
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import requests
4
+ import gradio as gr
5
+ from tqdm import tqdm
6
+ from PIL import Image
7
+ from model import Model
8
+ from torchvision import transforms
9
+ import warnings
10
+ warnings.filterwarnings("ignore")
11
+
12
+
13
+ def download_model(url="https://www.modelscope.cn/api/v1/models/MuGeminorum/SVHN-Recognition/repo?Revision=master&FilePath=model-122000.pth", local_path="model-122000.pth"):
14
+ # Check if the file exists
15
+ if not os.path.exists(local_path):
16
+ print(f"Downloading file from {url}...")
17
+ # Make a request to the URL
18
+ response = requests.get(url, stream=True)
19
+
20
+ # Get the total file size in bytes
21
+ total_size = int(response.headers.get('content-length', 0))
22
+
23
+ # Initialize the tqdm progress bar
24
+ progress_bar = tqdm(total=total_size, unit='B', unit_scale=True)
25
+
26
+ # Open a local file with write-binary mode
27
+ with open(local_path, 'wb') as file:
28
+ for data in response.iter_content(chunk_size=1024):
29
+ # Update the progress bar
30
+ progress_bar.update(len(data))
31
+
32
+ # Write the data to the local file
33
+ file.write(data)
34
+
35
+ # Close the progress bar
36
+ progress_bar.close()
37
+
38
+ print("Download completed.")
39
+
40
+
41
+ def _infer(path_to_checkpoint_file, path_to_input_image):
42
+ model = Model()
43
+ model.restore(path_to_checkpoint_file)
44
+ model.cuda()
45
+ outstr = ''
46
+
47
+ with torch.no_grad():
48
+ transform = transforms.Compose([
49
+ transforms.Resize([64, 64]),
50
+ transforms.CenterCrop([54, 54]),
51
+ transforms.ToTensor(),
52
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
53
+ ])
54
+
55
+ image = Image.open(path_to_input_image)
56
+ image = image.convert('RGB')
57
+ image = transform(image)
58
+ images = image.unsqueeze(dim=0).cuda()
59
+
60
+ length_logits, digit1_logits, digit2_logits, digit3_logits, digit4_logits, digit5_logits = model.eval()(images)
61
+
62
+ length_prediction = length_logits.max(1)[1]
63
+ digit1_prediction = digit1_logits.max(1)[1]
64
+ digit2_prediction = digit2_logits.max(1)[1]
65
+ digit3_prediction = digit3_logits.max(1)[1]
66
+ digit4_prediction = digit4_logits.max(1)[1]
67
+ digit5_prediction = digit5_logits.max(1)[1]
68
+
69
+ output = [
70
+ digit1_prediction.item(),
71
+ digit2_prediction.item(),
72
+ digit3_prediction.item(),
73
+ digit4_prediction.item(),
74
+ digit5_prediction.item()
75
+ ]
76
+
77
+ for i in range(length_prediction.item()):
78
+ outstr += str(output[i])
79
+
80
+ return outstr
81
+
82
+
83
+ def inference(image_path, weight_path="model-122000.pth"):
84
+ download_model()
85
+
86
+ if not image_path:
87
+ image_path = '457.png'
88
+
89
+ return _infer(weight_path, image_path)
90
+
91
+
92
+ if __name__ == '__main__':
93
+ iface = gr.Interface(
94
+ fn=inference,
95
+ inputs=gr.Image(type='filepath'),
96
+ outputs=gr.Textbox()
97
+ )
98
+
99
+ iface.launch()
model.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import torch
4
+ import torch.jit
5
+ import torch.nn as nn
6
+
7
+
8
+ class Model(torch.jit.ScriptModule):
9
+ CHECKPOINT_FILENAME_PATTERN = 'model-{}.pth'
10
+
11
+ __constants__ = [
12
+ '_hidden1', '_hidden2', '_hidden3', '_hidden4', '_hidden5', '_hidden6',
13
+ '_hidden7', '_hidden8', '_hidden9', '_hidden10', '_features', '_classifier',
14
+ '_digit_length', '_digit1', '_digit2', '_digit3', '_digit4', '_digit5'
15
+ ]
16
+
17
+ def __init__(self):
18
+ super(Model, self).__init__()
19
+
20
+ self._hidden1 = nn.Sequential(
21
+ nn.Conv2d(
22
+ in_channels=3,
23
+ out_channels=48,
24
+ kernel_size=5,
25
+ padding=2
26
+ ),
27
+ nn.BatchNorm2d(num_features=48),
28
+ nn.ReLU(),
29
+ nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
30
+ nn.Dropout(0.2)
31
+ )
32
+ self._hidden2 = nn.Sequential(
33
+ nn.Conv2d(
34
+ in_channels=48,
35
+ out_channels=64,
36
+ kernel_size=5,
37
+ padding=2
38
+ ),
39
+ nn.BatchNorm2d(num_features=64),
40
+ nn.ReLU(),
41
+ nn.MaxPool2d(kernel_size=2, stride=1, padding=1),
42
+ nn.Dropout(0.2)
43
+ )
44
+ self._hidden3 = nn.Sequential(
45
+ nn.Conv2d(
46
+ in_channels=64,
47
+ out_channels=128,
48
+ kernel_size=5,
49
+ padding=2
50
+ ),
51
+ nn.BatchNorm2d(num_features=128),
52
+ nn.ReLU(),
53
+ nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
54
+ nn.Dropout(0.2)
55
+ )
56
+ self._hidden4 = nn.Sequential(
57
+ nn.Conv2d(
58
+ in_channels=128,
59
+ out_channels=160,
60
+ kernel_size=5,
61
+ padding=2
62
+ ),
63
+ nn.BatchNorm2d(num_features=160),
64
+ nn.ReLU(),
65
+ nn.MaxPool2d(kernel_size=2, stride=1, padding=1),
66
+ nn.Dropout(0.2)
67
+ )
68
+ self._hidden5 = nn.Sequential(
69
+ nn.Conv2d(
70
+ in_channels=160,
71
+ out_channels=192,
72
+ kernel_size=5,
73
+ padding=2
74
+ ),
75
+ nn.BatchNorm2d(num_features=192),
76
+ nn.ReLU(),
77
+ nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
78
+ nn.Dropout(0.2)
79
+ )
80
+ self._hidden6 = nn.Sequential(
81
+ nn.Conv2d(
82
+ in_channels=192,
83
+ out_channels=192,
84
+ kernel_size=5,
85
+ padding=2
86
+ ),
87
+ nn.BatchNorm2d(num_features=192),
88
+ nn.ReLU(),
89
+ nn.MaxPool2d(kernel_size=2, stride=1, padding=1),
90
+ nn.Dropout(0.2)
91
+ )
92
+ self._hidden7 = nn.Sequential(
93
+ nn.Conv2d(
94
+ in_channels=192,
95
+ out_channels=192,
96
+ kernel_size=5,
97
+ padding=2
98
+ ),
99
+ nn.BatchNorm2d(num_features=192),
100
+ nn.ReLU(),
101
+ nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
102
+ nn.Dropout(0.2)
103
+ )
104
+ self._hidden8 = nn.Sequential(
105
+ nn.Conv2d(
106
+ in_channels=192,
107
+ out_channels=192,
108
+ kernel_size=5,
109
+ padding=2
110
+ ),
111
+ nn.BatchNorm2d(num_features=192),
112
+ nn.ReLU(),
113
+ nn.MaxPool2d(kernel_size=2, stride=1, padding=1),
114
+ nn.Dropout(0.2)
115
+ )
116
+ self._hidden9 = nn.Sequential(
117
+ nn.Linear(192 * 7 * 7, 3072),
118
+ nn.ReLU()
119
+ )
120
+ self._hidden10 = nn.Sequential(
121
+ nn.Linear(3072, 3072),
122
+ nn.ReLU()
123
+ )
124
+
125
+ self._digit_length = nn.Sequential(nn.Linear(3072, 7))
126
+ self._digit1 = nn.Sequential(nn.Linear(3072, 11))
127
+ self._digit2 = nn.Sequential(nn.Linear(3072, 11))
128
+ self._digit3 = nn.Sequential(nn.Linear(3072, 11))
129
+ self._digit4 = nn.Sequential(nn.Linear(3072, 11))
130
+ self._digit5 = nn.Sequential(nn.Linear(3072, 11))
131
+
132
+ @torch.jit.script_method
133
+ def forward(self, x):
134
+ x = self._hidden1(x)
135
+ x = self._hidden2(x)
136
+ x = self._hidden3(x)
137
+ x = self._hidden4(x)
138
+ x = self._hidden5(x)
139
+ x = self._hidden6(x)
140
+ x = self._hidden7(x)
141
+ x = self._hidden8(x)
142
+ x = x.view(x.size(0), 192 * 7 * 7)
143
+ x = self._hidden9(x)
144
+ x = self._hidden10(x)
145
+
146
+ length_logits = self._digit_length(x)
147
+ digit1_logits = self._digit1(x)
148
+ digit2_logits = self._digit2(x)
149
+ digit3_logits = self._digit3(x)
150
+ digit4_logits = self._digit4(x)
151
+ digit5_logits = self._digit5(x)
152
+
153
+ return length_logits, digit1_logits, digit2_logits, digit3_logits, digit4_logits, digit5_logits
154
+
155
+ def store(self, path_to_dir, step, maximum=5):
156
+ path_to_models = glob.glob(os.path.join(
157
+ path_to_dir, Model.CHECKPOINT_FILENAME_PATTERN.format('*')))
158
+ if len(path_to_models) == maximum:
159
+ min_step = min(
160
+ [int(path_to_model.split('\\')[-1][6:-4])
161
+ for path_to_model in path_to_models]
162
+ )
163
+ path_to_min_step_model = os.path.join(
164
+ path_to_dir,
165
+ Model.CHECKPOINT_FILENAME_PATTERN.format(min_step)
166
+ )
167
+ os.remove(path_to_min_step_model)
168
+
169
+ path_to_checkpoint_file = os.path.join(
170
+ path_to_dir, Model.CHECKPOINT_FILENAME_PATTERN.format(step)
171
+ )
172
+ torch.save(self.state_dict(), path_to_checkpoint_file)
173
+ return path_to_checkpoint_file
174
+
175
+ def restore(self, path_to_checkpoint_file):
176
+ self.load_state_dict(torch.load(path_to_checkpoint_file))
177
+ step = int(path_to_checkpoint_file.split('\\')[-1][6:-4])
178
+ return step
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ pillow
3
+ torch
4
+ torchvision