liuyizhang commited on
Commit
82b6069
1 Parent(s): c2a6c29

support gradio & api

Browse files
Files changed (3) hide show
  1. api_client.py +69 -0
  2. app.py +197 -65
  3. requirements.txt +1 -5
api_client.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests, json
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import base64
6
+ import io
7
+
8
+ def request_post(url, data, timeout=600, headers = None):
9
+ if headers is None:
10
+ headers = {
11
+ # 'content-type': 'application/json'
12
+ # 'Connection': 'keep-alive',
13
+ 'Accept': '*/*', # 接受任何类型的返回数据
14
+ 'Content-Type': 'application/json;charset=UTF-8', # 发送数据为json
15
+ # 'Content-Length': '156',
16
+ # 'Accept-Encoding': 'gzip, deflate',
17
+ # 'Accept-Language': 'zh-CN,zh;q=0.9',
18
+ # 'User-Agent': 'SamClub/5.0.45 (iPhone; iOS 15.4; Scale/3.00)',
19
+ # 'device-name': 'iPhone14,3',
20
+ # 'device-os-version': '15.4',
21
+ # 'device-type': 'ios',
22
+ # 'auth-token': authtoken,
23
+ # 'app-version': '5.0.45.1'
24
+ }
25
+ try:
26
+ response = requests.post(url=url, headers=headers, data=json.dumps(data), timeout=timeout)
27
+ response_data = response.json()
28
+ return response_data
29
+ except Exception as e:
30
+ print(f'request_post[Error]:' + str(e))
31
+ print(f'url: {url}')
32
+ print(f'data: {data}')
33
+ print(f'response: {response}')
34
+ return None
35
+
36
+ url = "http://127.0.0.1:7860/imgCLeaner"
37
+
38
+ def imgFile_to_base64(image_file):
39
+ with open(image_file, "rb") as f:
40
+ im_bytes = f.read()
41
+ im_b64_encode = base64.b64encode(im_bytes)
42
+ im_b64 = im_b64_encode.decode("utf8")
43
+ return im_b64
44
+
45
+ def base64_to_bytes(im_b64):
46
+ im_b64_encode = im_b64.encode("utf-8")
47
+ im_bytes = base64.b64decode(im_b64_encode)
48
+ return im_bytes
49
+
50
+ def base64_to_PILImage(im_b64):
51
+ im_bytes = base64_to_bytes(im_b64)
52
+ pil_img = Image.open(io.BytesIO(im_bytes))
53
+ return pil_img
54
+
55
+ image_file = 'dog.png'
56
+ data = {'remove_texts': "小狗 . 椅子",
57
+ 'extend': 20,
58
+ 'img': imgFile_to_base64(image_file),
59
+ }
60
+
61
+ ret = request_post(url, data, timeout=600, headers = None)
62
+ print(len(ret['result']['imgs']))
63
+
64
+ for img in ret['result']['imgs']:
65
+ pilImage = base64_to_PILImage(img)
66
+ plt.imshow(pilImage)
67
+ plt.show()
68
+ plt.clf()
69
+
app.py CHANGED
@@ -120,7 +120,6 @@ ram_model = None
120
  kosmos_model = None
121
  kosmos_processor = None
122
 
123
-
124
  def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
125
  args = SLConfig.fromfile(model_config_path)
126
  model = build_model(args)
@@ -621,7 +620,8 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
621
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
622
 
623
  size = image_pil.size
624
-
 
625
  # run grounding dino model
626
  if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
627
  pass
@@ -655,25 +655,35 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
655
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
656
  if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
657
  image = np.array(input_img)
658
- sam_predictor.set_image(image)
 
659
 
660
- H, W = size[1], size[0]
661
  for i in range(boxes_filt.size(0)):
662
  boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
663
  boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
664
  boxes_filt[i][2:] += boxes_filt[i][:2]
665
 
666
- boxes_filt = boxes_filt.to(sam_device)
667
- transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
668
 
669
- masks, _, _, _ = sam_predictor.predict_torch(
670
- point_coords = None,
671
- point_labels = None,
672
- boxes = transformed_boxes,
673
- multimask_output = False,
674
- )
675
- # masks: [9, 1, 512, 512]
676
- assert sam_checkpoint, 'sam_checkpoint is not found!'
677
  # draw output image
678
  plt.figure(figsize=(10, 10))
679
  plt.imshow(image)
@@ -686,7 +696,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
686
  plt.savefig(image_path, bbox_inches="tight")
687
  segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
688
  os.remove(image_path)
689
- output_images.append(segment_image_result)
690
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
691
 
692
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
@@ -705,9 +715,9 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
705
  masks_ori = copy.deepcopy(masks)
706
  if inpaint_mode == 'merge':
707
  masks = torch.sum(masks, dim=0).unsqueeze(0)
708
- masks = torch.where(masks > 0, True, False)
709
  mask = masks[0][0].cpu().numpy()
710
- mask_pil = Image.fromarray(mask)
711
  output_images.append(mask_pil.convert("RGB"))
712
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
713
 
@@ -718,7 +728,6 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
718
  image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
719
  else:
720
  # remove from mask
721
- logger.info(f'run_anything_task_[{file_temp}]_{task_type}_5_')
722
  if mask_source_radio == mask_source_segment:
723
  mask_imgs = []
724
  masks_shape = masks_ori.shape
@@ -732,19 +741,17 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
732
  for i in range(extend_shape_0):
733
  for j in range(extend_shape_1):
734
  mask = masks_ori[i][j].cpu().numpy()
735
- mask_pil = Image.fromarray(mask)
736
-
737
  if remove_mode == 'segment':
738
  useRectangle = False
739
  else:
740
  useRectangle = True
741
-
742
  try:
743
  remove_mask_extend = int(remove_mask_extend)
744
  except:
745
  remove_mask_extend = 10
746
  mask_pil_exp = mask_extend(copy.deepcopy(mask_pil).convert("RGB"),
747
- xywh_to_xyxy(torch.tensor(boxes_filt_ori_array[i]), size[0], size[1]),
748
  extend_pixels=remove_mask_extend, useRectangle=useRectangle)
749
  mask_imgs.append(mask_pil_exp)
750
  mask_pil = mix_masks(mask_imgs)
@@ -820,48 +827,7 @@ def get_model_device(module):
820
  except Exception as e:
821
  return 'Error'
822
 
823
- if __name__ == "__main__":
824
- parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
825
- parser.add_argument("--debug", action="store_true", help="using debug mode")
826
- parser.add_argument("--share", action="store_true", help="share the app")
827
- args, _ = parser.parse_known_args()
828
- print(f'args = {args}')
829
-
830
- if os.environ.get('IS_MY_DEBUG') is None:
831
- os.system("pip list")
832
-
833
- device = set_device()
834
- if device == 'cpu':
835
- kosmos_enable = False
836
-
837
- if kosmos_enable:
838
- kosmos_model, kosmos_processor = load_kosmos_model(device)
839
-
840
- if groundingdino_enable:
841
- groundingdino_model = load_groundingdino_model('cpu')
842
-
843
- if sam_enable:
844
- load_sam_model(device)
845
-
846
- if inpainting_enable:
847
- load_sd_model(device)
848
-
849
- if lama_cleaner_enable:
850
- load_lama_cleaner_model(device)
851
-
852
- if ram_enable:
853
- load_ram_model(device)
854
-
855
- if os.environ.get('IS_MY_DEBUG') is None:
856
- os.system("pip list")
857
-
858
- # print(f'groundingdino_model__{get_model_device(groundingdino_model)}')
859
- # print(f'sam_model__{get_model_device(sam_model)}')
860
- # print(f'sd_model__{get_model_device(sd_model)}')
861
- # print(f'lama_cleaner_model__{get_model_device(lama_cleaner_model)}')
862
- # print(f'ram_model__{get_model_device(ram_model)}')
863
- # print(f'kosmos_model__{get_model_device(kosmos_model)}')
864
-
865
  block = gr.Blocks().queue()
866
  with block:
867
  with gr.Row():
@@ -968,5 +934,171 @@ if __name__ == "__main__":
968
  print(f'device = {device}')
969
  print(f'torch.cuda.is_available = {torch.cuda.is_available()}')
970
  computer_info()
971
- block.launch(server_name='0.0.0.0', debug=args.debug, share=args.share)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
972
 
 
120
  kosmos_model = None
121
  kosmos_processor = None
122
 
 
123
  def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
124
  args = SLConfig.fromfile(model_config_path)
125
  model = build_model(args)
 
620
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
621
 
622
  size = image_pil.size
623
+ H, W = size[1], size[0]
624
+
625
  # run grounding dino model
626
  if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
627
  pass
 
655
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
656
  if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
657
  image = np.array(input_img)
658
+ if sam_predictor:
659
+ sam_predictor.set_image(image)
660
 
 
661
  for i in range(boxes_filt.size(0)):
662
  boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
663
  boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
664
  boxes_filt[i][2:] += boxes_filt[i][:2]
665
 
666
+ if sam_predictor:
667
+ boxes_filt = boxes_filt.to(sam_device)
668
+ transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
669
+
670
+ masks, _, _, _ = sam_predictor.predict_torch(
671
+ point_coords = None,
672
+ point_labels = None,
673
+ boxes = transformed_boxes,
674
+ multimask_output = False,
675
+ )
676
+ # masks: [9, 1, 512, 512]
677
+ assert sam_checkpoint, 'sam_checkpoint is not found!'
678
+ else:
679
+ masks = torch.zeros(len(boxes_filt), 1, H, W)
680
+ mask_count = 0
681
+ for box in boxes_filt:
682
+ masks[mask_count, 0, int(box[1]):int(box[3]), int(box[0]):int(box[2])] = 1
683
+ mask_count += 1
684
+ masks = torch.where(masks > 0, True, False)
685
+ run_mode = "rectangle"
686
 
 
 
 
 
 
 
 
 
687
  # draw output image
688
  plt.figure(figsize=(10, 10))
689
  plt.imshow(image)
 
696
  plt.savefig(image_path, bbox_inches="tight")
697
  segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
698
  os.remove(image_path)
699
+ output_images.append(Image.fromarray(segment_image_result))
700
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
701
 
702
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
 
715
  masks_ori = copy.deepcopy(masks)
716
  if inpaint_mode == 'merge':
717
  masks = torch.sum(masks, dim=0).unsqueeze(0)
718
+ masks = torch.where(masks > 0, True, False)
719
  mask = masks[0][0].cpu().numpy()
720
+ mask_pil = Image.fromarray(mask)
721
  output_images.append(mask_pil.convert("RGB"))
722
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
723
 
 
728
  image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
729
  else:
730
  # remove from mask
 
731
  if mask_source_radio == mask_source_segment:
732
  mask_imgs = []
733
  masks_shape = masks_ori.shape
 
741
  for i in range(extend_shape_0):
742
  for j in range(extend_shape_1):
743
  mask = masks_ori[i][j].cpu().numpy()
744
+ mask_pil = Image.fromarray(mask)
 
745
  if remove_mode == 'segment':
746
  useRectangle = False
747
  else:
748
  useRectangle = True
 
749
  try:
750
  remove_mask_extend = int(remove_mask_extend)
751
  except:
752
  remove_mask_extend = 10
753
  mask_pil_exp = mask_extend(copy.deepcopy(mask_pil).convert("RGB"),
754
+ xywh_to_xyxy(torch.tensor(boxes_filt_ori_array[i]), W, H),
755
  extend_pixels=remove_mask_extend, useRectangle=useRectangle)
756
  mask_imgs.append(mask_pil_exp)
757
  mask_pil = mix_masks(mask_imgs)
 
827
  except Exception as e:
828
  return 'Error'
829
 
830
+ def main_gradio(args):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
831
  block = gr.Blocks().queue()
832
  with block:
833
  with gr.Row():
 
934
  print(f'device = {device}')
935
  print(f'torch.cuda.is_available = {torch.cuda.is_available()}')
936
  computer_info()
937
+ block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share)
938
+
939
+ import signal
940
+ import json
941
+ from datetime import date, datetime, timedelta
942
+ from gevent import pywsgi
943
+ import base64
944
+
945
+ def imgFile_to_base64(image_file):
946
+ with open(image_file, "rb") as f:
947
+ im_bytes = f.read()
948
+ im_b64_encode = base64.b64encode(im_bytes)
949
+ im_b64 = im_b64_encode.decode("utf8")
950
+ return im_b64
951
+
952
+ def base64_to_bytes(im_b64):
953
+ im_b64_encode = im_b64.encode("utf-8")
954
+ im_bytes = base64.b64decode(im_b64_encode)
955
+ return im_bytes
956
+
957
+ def base64_to_PILImage(im_b64):
958
+ im_bytes = base64_to_bytes(im_b64)
959
+ pil_img = Image.open(io.BytesIO(im_bytes))
960
+ return pil_img
961
+
962
+ class API_Starter:
963
+ def __init__(self):
964
+ from flask import Flask, request, jsonify, make_response
965
+ from flask_cors import CORS, cross_origin
966
+ import logging
967
+
968
+ app = Flask(__name__)
969
+ app.logger.setLevel(logging.ERROR)
970
+ CORS(app, supports_credentials=True, resources={r"/*": {"origins": "*"}})
971
+
972
+ @app.route('/imgCLeaner', methods=['GET', 'POST'])
973
+ @cross_origin()
974
+ def processAssist():
975
+ if request.method == 'GET':
976
+ ret_json = {'code': -1, 'reason':'no support to get'}
977
+ elif request.method == 'POST':
978
+ request_data = request.data.decode('utf-8')
979
+ data = json.loads(request_data)
980
+ result = self.handle_data(data)
981
+ ret_json = {'code': 0, 'result':result}
982
+ return jsonify(ret_json)
983
+
984
+ self.app = app
985
+ now_time = datetime.now().strftime('%Y%m%d_%H%M%S')
986
+ logger.add(f'./logs/logger_[{args.port}]_{now_time}.log')
987
+ signal.signal(signal.SIGINT, self.signal_handler)
988
+
989
+ def handle_data(self, data):
990
+ im_b64 = data['img']
991
+ img = base64_to_PILImage(im_b64)
992
+ results = run_anything_task(input_image = img,
993
+ text_prompt = data['remove_texts'],
994
+ task_type = 'remove',
995
+ inpaint_prompt = '',
996
+ box_threshold = 0.3,
997
+ text_threshold = 0.25,
998
+ iou_threshold = 0.8,
999
+ inpaint_mode = "merge",
1000
+ mask_source_radio = "type what to detect below",
1001
+ remove_mode = "rectangle", # ["segment", "rectangle"]
1002
+ remove_mask_extend = "10",
1003
+ num_relation = 5,
1004
+ kosmos_input = None,
1005
+ cleaner_size_limit = -1,
1006
+ )
1007
+ output_images = results[0]
1008
+ ret_json_images = []
1009
+ file_temp = int(time.time())
1010
+ count = 0
1011
+ for image_pil in output_images:
1012
+ try:
1013
+ img_format = image_pil.format.lower()
1014
+ except Exception as e:
1015
+ img_format = 'png'
1016
+ image_path = os.path.join(output_dir, f"api_images_{file_temp}_{count}.{img_format}")
1017
+ count += 1
1018
+ try:
1019
+ image_pil.save(image_path)
1020
+ except Exception as e:
1021
+ Image.fromarray(image_pil).save(image_path)
1022
+ im_b64 = imgFile_to_base64(image_path)
1023
+ ret_json_images.append(im_b64)
1024
+ os.remove(image_path)
1025
+ data = {
1026
+ 'imgs': ret_json_images,
1027
+ }
1028
+ return data
1029
+
1030
+ def signal_handler(self, signal, frame):
1031
+ print('\nSignal Catched! You have just type Ctrl+C!')
1032
+ sys.exit(0)
1033
+
1034
+ def run(self):
1035
+ from gevent import pywsgi
1036
+ logger.info(f'\nargs={args}\n')
1037
+ computer_info()
1038
+ server = pywsgi.WSGIServer(('0.0.0.0', args.port), self.app)
1039
+ server.serve_forever()
1040
+
1041
+ def main_api(args):
1042
+ if args.port == 0:
1043
+ print('Please give valid port!')
1044
+ else:
1045
+ api_starter = API_Starter()
1046
+ api_starter.run()
1047
+
1048
+ if __name__ == "__main__":
1049
+ parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
1050
+ parser.add_argument("--debug", action="store_true", help="using debug mode")
1051
+ parser.add_argument("--share", action="store_true", help="share the app")
1052
+ parser.add_argument("--port", "-p", type=int, default=7860, help="port")
1053
+ args, _ = parser.parse_known_args()
1054
+ print(f'args = {args}')
1055
+
1056
+ if os.environ.get('IS_MY_DEBUG') is None:
1057
+ os.system("pip list")
1058
+
1059
+ device = set_device()
1060
+ if device == 'cpu':
1061
+ kosmos_enable = False
1062
+
1063
+ if kosmos_enable:
1064
+ kosmos_model, kosmos_processor = load_kosmos_model(device)
1065
+
1066
+ if groundingdino_enable:
1067
+ groundingdino_model = load_groundingdino_model('cpu')
1068
+
1069
+ if sam_enable:
1070
+ load_sam_model(device)
1071
+
1072
+ if inpainting_enable:
1073
+ load_sd_model(device)
1074
+
1075
+ if lama_cleaner_enable:
1076
+ load_lama_cleaner_model(device)
1077
+
1078
+ if ram_enable:
1079
+ load_ram_model(device)
1080
+
1081
+ if os.environ.get('IS_MY_DEBUG') is None:
1082
+ os.system("pip list")
1083
+
1084
+ # print(f'groundingdino_model__{get_model_device(groundingdino_model)}')
1085
+ # print(f'sam_model__{get_model_device(sam_model)}')
1086
+ # print(f'sd_model__{get_model_device(sd_model)}')
1087
+ # print(f'lama_cleaner_model__{get_model_device(lama_cleaner_model)}')
1088
+ # print(f'ram_model__{get_model_device(ram_model)}')
1089
+ # print(f'kosmos_model__{get_model_device(kosmos_model)}')
1090
+
1091
+ if os.environ.get('IS_MY_DEBUG') is None:
1092
+ # Provide gradio services
1093
+ main_gradio(args)
1094
+ else:
1095
+ if 0 == 0:
1096
+ # Provide API services
1097
+ main_api(args)
1098
+ else:
1099
+ # Provide gradio services
1100
+ main_gradio(args)
1101
+
1102
+
1103
+
1104
 
requirements.txt CHANGED
@@ -15,14 +15,10 @@ setuptools
15
  supervision
16
  termcolor
17
  timm
18
- # torch
19
- # torchvision
20
  torch==2.0.0
21
  torchvision==0.15.1
22
 
23
- # torch==2.1.0
24
- # torchvision==0.16.0
25
-
26
  yapf
27
  numba
28
  scipy
 
15
  supervision
16
  termcolor
17
  timm
 
 
18
  torch==2.0.0
19
  torchvision==0.15.1
20
 
21
+ gevent
 
 
22
  yapf
23
  numba
24
  scipy