AAAAAAyq committed
Commit d852f6a
Parent: e0a1444
Files changed (1): app.py (+46, -6)
app.py CHANGED
@@ -4,7 +4,11 @@ import matplotlib.pyplot as plt
 import gradio as gr
 import cv2
 import torch
-from PIL import Image
+# import queue
+# import time
+
+# from PIL import Image
+
 
 model = YOLO('checkpoints/FastSAM.pt')  # load a custom model
 
@@ -132,15 +136,37 @@ def fast_show_mask_gpu(annotation, ax,
     plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
     ax.imshow(show_cpu)
 
-# post_process(results[0].masks, Image.open("../data/cake.png"))
+
+# # Set up the request queue and a lock for thread synchronization
+# request_queue = queue.Queue(maxsize=10)
+# lock = queue.Queue()
 
 def predict(input, input_size=512, high_visual_quality=True):
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
     input_size = int(input_size)  # make sure imgsz is an integer
     results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
-    pil_image = fast_process(annotations=results[0].masks.data,
+    fig = fast_process(annotations=results[0].masks.data,
                              image=input, high_quality=high_visual_quality, device=device)
-    return pil_image
+    return fig
+
+# # Wrap the time-consuming function in another function that manages the queue and thread synchronization
+# def process_request():
+#     while True:
+#         if not request_queue.empty():
+#             # If the request queue is not empty, handle the request
+#             try:
+#                 lock.put(1)  # acquire the lock to avoid handling several requests at once
+#                 input_package = request_queue.get()
+#                 fig = predict(input_package)
+#                 request_queue.task_done()  # request handled; remove it from the queue
+#                 lock.get()  # release the lock
+#                 yield fig  # yield the prediction result
+#             except:
+#                 lock.get()  # release the lock on errors as well
+#         else:
+#             # If the request queue is empty, wait for new requests to arrive
+#             time.sleep(1)
+
 
 # input_size=1024
 # high_quality_visual=True
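
A note on the commented-out synchronization above: lock = queue.Queue() is unbounded, so lock.put(1) never blocks, and the "lock" would not actually stop two requests from running at once. A minimal sketch of the intended serialization with a real mutex (threading.Lock is my substitution, and predict_serialized is a hypothetical wrapper, not part of this commit):

import threading

# Hypothetical sketch: serialize calls to predict() with a real mutex
# instead of the queue-as-lock in the commented-out code above.
predict_lock = threading.Lock()

def predict_serialized(*args, **kwargs):
    # Block until the previous request finishes, then run the model once.
    with predict_lock:
        return predict(*args, **kwargs)
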
@@ -151,7 +177,7 @@ def predict(input, input_size=512, high_visual_quality=True):
 # results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
 # pil_image = fast_process(annotations=results[0].masks.data,
 #                          image=input, high_quality=high_quality_visual, device=device)
-demo = gr.Interface(fn=predict,
+app_interface = gr.Interface(fn=predict,
                     inputs=[gr.components.Image(type='pil'),
                             gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64),
                             gr.components.Checkbox(value=True)],
@@ -163,6 +189,20 @@ demo = gr.Interface(fn=predict,
                     ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
                     ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
                     cache_examples=False,
+                    title="Fast Segment Anything (Everything mode)"
                     )
 
-demo.launch()
+# # Define a request handler that adds incoming requests to the queue
+# def handle_request(value):
+#     try:
+#         request_queue.put_nowait(value)  # add the request to the queue
+#     except:
+#         return "The queue is currently full, please try again later!"
+#     return None
+
+# # Attach the request handler to the app interface
+# app_interface.add_transition("submit", handle_request)
+
+
+app_interface.queue(concurrency_count=2)
+app_interface.launch()
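
The queue(concurrency_count=2) call enabled at the end is Gradio's built-in request queue, which supersedes the hand-rolled queue that remains commented out. A minimal standalone sketch of the same throttling (assuming Gradio 3.x, where queue() accepts concurrency_count; the echo app is a placeholder, not part of app.py):

import gradio as gr

# Placeholder app to demonstrate the built-in queue; not from the commit.
def echo(text):
    return text

demo = gr.Interface(fn=echo, inputs='text', outputs='text')
demo.queue(concurrency_count=2)  # at most two requests are processed at once
demo.launch()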
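
One more note on the predict hunk above: renaming pil_image to fig suggests fast_process now returns a Matplotlib figure rather than a PIL image. If a PIL image were still needed downstream, a figure can be rendered into one through an in-memory buffer (a generic sketch; fig_to_pil is a hypothetical helper, not part of app.py):

import io

import matplotlib.pyplot as plt
from PIL import Image

def fig_to_pil(fig):
    # Render the figure to an in-memory PNG, then load it as a PIL image.
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    plt.close(fig)  # free the figure once rendered
    buf.seek(0)
    return Image.open(buf)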