AAAAAAyq commited on
Commit
ddaa443
1 Parent(s): d852f6a
Files changed (2) hide show
  1. app.py +53 -22
  2. requirements.txt +2 -2
app.py CHANGED
@@ -5,8 +5,7 @@ import gradio as gr
5
  import cv2
6
  import torch
7
  # import queue
8
- # import time
9
-
10
  # from PIL import Image
11
 
12
 
@@ -137,18 +136,51 @@ def fast_show_mask_gpu(annotation, ax,
137
  ax.imshow(show_cpu)
138
 
139
 
140
- # # 建立请求队列和线程同步锁
141
- # request_queue = queue.Queue(maxsize=10)
142
- # lock = queue.Queue()
 
 
 
 
143
 
144
- def predict(input, input_size=512, high_visual_quality=True):
145
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
146
  input_size = int(input_size) # 确保 imgsz 是整数
 
 
 
 
 
 
 
 
 
 
 
147
  results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
148
  fig = fast_process(annotations=results[0].masks.data,
149
  image=input, high_quality=high_visual_quality, device=device)
150
  return fig
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  # # 将耗时的函数包装在另一个函数中,用于控制队列和线程同步
153
  # def process_request():
154
  # while True:
@@ -156,8 +188,8 @@ def predict(input, input_size=512, high_visual_quality=True):
156
  # # 如果请求队列不为空,则处理该请求
157
  # try:
158
  # lock.put(1) # 加锁,防止同时处理多个请求
159
- # input_package = request_queue.get()
160
- # fig = predict(input_package)
161
  # request_queue.task_done() # 请求处理结束,移除请求
162
  # lock.get() # 解锁
163
  # yield fig # 返回预测结果
@@ -179,17 +211,17 @@ def predict(input, input_size=512, high_visual_quality=True):
179
  # image=input, high_quality=high_quality_visual, device=device)
180
  app_interface = gr.Interface(fn=predict,
181
  inputs=[gr.components.Image(type='pil'),
182
- gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64),
183
- gr.components.Checkbox(value=True)],
184
  outputs=['plot'],
185
- # examples=[["assets/sa_8776.jpg", 1024, True]],
186
- # ["assets/sa_1309.jpg", 1024]],
187
- examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
188
- ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
189
- ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
190
- ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
191
- cache_examples=False,
192
- title="Fast Segment Anthing (Everything mode)"
193
  )
194
 
195
  # # 定义一个请求处理函数,将请求添加到队列中
@@ -201,8 +233,7 @@ app_interface = gr.Interface(fn=predict,
201
  # return None
202
 
203
  # # 添加请求处理函数到应用程序界面
204
- # app_interface.add_transition("submit", handle_request)
205
-
206
 
207
- app_interface.queue(concurrency_count=2)
208
  app_interface.launch()
 
5
  import cv2
6
  import torch
7
  # import queue
8
+ # import threading
 
9
  # from PIL import Image
10
 
11
 
 
136
  ax.imshow(show_cpu)
137
 
138
 
139
+ # # 预测队列
140
+ # prediction_queue = queue.Queue(maxsize=5)
141
+
142
+ # # 线程锁
143
+ # lock = threading.Lock()
144
+
145
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
146
 
147
+ def predict(input, input_size=512, high_visual_quality=False):
 
148
  input_size = int(input_size) # 确保 imgsz 是整数
149
+ # # 获取线程锁
150
+ # with lock:
151
+ # print('5')
152
+ # # 将任务添加到队列
153
+ # prediction_queue.put((input, input_size, high_visual_quality))
154
+
155
+ # # 等待结果
156
+ # print('6')
157
+ # fig = prediction_queue.get()[0]
158
+ # print(fig)
159
+ # return fig
160
  results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
161
  fig = fast_process(annotations=results[0].masks.data,
162
  image=input, high_quality=high_visual_quality, device=device)
163
  return fig
164
 
165
+ # def worker():
166
+ # while True:
167
+ # # 从队列获取任务
168
+ # print('1')
169
+ # input, input_size, high_visual_quality = prediction_queue.get()
170
+
171
+ # # 执行模型预测
172
+ # print('2')
173
+ # results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
174
+ # print('3')
175
+ # fig = fast_process(annotations=results[0].masks.data,
176
+ # image=input, high_quality=high_visual_quality, device=device)
177
+ # print('4')
178
+ # # 将结果放回队列
179
+ # prediction_queue.put(fig)
180
+
181
+ # # 在一个新的线程中启动工作函数
182
+ # threading.Thread(target=worker).start()
183
+
184
  # # 将耗时的函数包装在另一个函数中,用于控制队列和线程同步
185
  # def process_request():
186
  # while True:
 
188
  # # 如果请求队列不为空,则处理该请求
189
  # try:
190
  # lock.put(1) # 加锁,防止同时处理多个请求
191
+ # input, input_size, high_visual_quality = request_queue.get()
192
+ # fig = predict(input, input_size, high_visual_quality)
193
  # request_queue.task_done() # 请求处理结束,移除请求
194
  # lock.get() # 解锁
195
  # yield fig # 返回预测结果
 
211
  # image=input, high_quality=high_quality_visual, device=device)
212
  app_interface = gr.Interface(fn=predict,
213
  inputs=[gr.components.Image(type='pil'),
214
+ gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='input_size'),
215
+ gr.components.Checkbox(value=False, label='high_visual_quality')],
216
  outputs=['plot'],
217
+ examples=[["assets/sa_8776.jpg", 1024, True]],
218
+ # # ["assets/sa_1309.jpg", 1024]],
219
+ # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
220
+ # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
221
+ # ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
222
+ # ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
223
+ cache_examples=True,
224
+ title="Fast Segment Anything (Everything mode)"
225
  )
226
 
227
  # # 定义一个请求处理函数,将请求添加到队列中
 
233
  # return None
234
 
235
  # # 添加请求处理函数到应用程序界面
236
+ # app_interface.call_function()
 
237
 
238
+ app_interface.queue(concurrency_count=1, max_size=20)
239
  app_interface.launch()
requirements.txt CHANGED
@@ -6,8 +6,8 @@ opencv-python
6
  # PyYAML>=5.3.1
7
  # requests>=2.23.0
8
  # scipy>=1.4.1
9
- torch
10
- torchvision
11
  # tqdm>=4.64.0
12
 
13
  # pandas>=1.1.4
 
6
  # PyYAML>=5.3.1
7
  # requests>=2.23.0
8
  # scipy>=1.4.1
9
+ # torch
10
+ # torchvision
11
  # tqdm>=4.64.0
12
 
13
  # pandas>=1.1.4