ttengwang committed on
Commit 5c74464
1 Parent(s): 7d57bb5

Fix bugs in example images and API keys

Image/demo1.svg CHANGED
Image/demo2.svg CHANGED
app.py CHANGED
@@ -40,16 +40,16 @@ description = """Gradio demo for Caption Anything, image to dense captioning gen
40
  """
41
 
42
  examples = [
 
43
  ["test_img/img2.jpg"],
44
  ["test_img/img5.jpg"],
45
  ["test_img/img12.jpg"],
46
  ["test_img/img14.jpg"],
47
  ]
48
 
49
  args = parse_augment()
50
- args.captioner = 'blip2'
51
- args.seg_crop_mode = 'wo_bg'
52
- args.regular_box = True
53
  # args.device = 'cuda:5'
54
  # args.disable_gpt = False
55
  # args.enable_reduce_tokens = True
@@ -57,9 +57,10 @@ args.regular_box = True
57
  model = CaptionAnything(args)
58
 
59
  def init_openai_api_key(api_key):
60
- os.environ['OPENAI_API_KEY'] = api_key
61
- model.init_refiner()
62
-
 
63
 
64
  def get_prompt(chat_input, click_state):
65
  points = click_state[0]
@@ -78,7 +79,7 @@ def get_prompt(chat_input, click_state):
78
  return prompt
79
 
80
  def chat_with_points(chat_input, click_state, state):
81
- if not hasattr(model, "text_refiner"):
82
  response = "Text refiner is not initilzed, please input openai api key."
83
  state = state + [(chat_input, response)]
84
  return state, state
@@ -132,7 +133,7 @@ def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality
132
  image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
133
 
134
  yield state, state, click_state, chat_input, image_input, wiki
135
- if not args.disable_gpt and hasattr(model, "text_refiner"):
136
  refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
137
  # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
138
  new_cap = refined_caption['caption']
@@ -143,10 +144,16 @@ def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality
143
  def upload_callback(image_input, state):
144
  state = [] + [('Image size: ' + str(image_input.size), None)]
145
  click_state = [[], [], []]
146
  model.segmenter.image = None
147
  model.segmenter.image_embedding = None
148
  model.segmenter.set_image(image_input)
149
- return state, image_input, click_state
150
 
151
  with gr.Blocks(
152
  css='''
@@ -163,55 +170,62 @@ with gr.Blocks(
163
 
164
  with gr.Row():
165
  with gr.Column(scale=1.0):
166
- image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
167
- with gr.Row(scale=1.0):
168
- point_prompt = gr.Radio(
169
- choices=["Positive", "Negative"],
170
- value="Positive",
171
- label="Point Prompt",
172
- interactive=True)
173
- clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
174
- clear_button_image = gr.Button(value="Clear Image", interactive=True)
175
- with gr.Row(scale=1.0):
176
- language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
177
-
178
- sentiment = gr.Radio(
179
- choices=["Positive", "Natural", "Negative"],
180
- value="Natural",
181
- label="Sentiment",
182
- interactive=True,
183
- )
184
- with gr.Row(scale=1.0):
185
- factuality = gr.Radio(
186
- choices=["Factual", "Imagination"],
187
- value="Factual",
188
- label="Factuality",
189
- interactive=True,
190
- )
191
- length = gr.Slider(
192
- minimum=10,
193
- maximum=80,
194
- value=10,
195
- step=1,
196
- interactive=True,
197
- label="Length",
198
- )
199
 
200
  with gr.Column(scale=0.5):
201
  openai_api_key = gr.Textbox(
202
- placeholder="Input your openAI API key and press Enter",
203
  show_label=False,
204
  label = "OpenAI API Key",
205
  lines=1,
206
  type="password"
207
  )
208
- openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
209
- wiki_output = gr.Textbox(lines=6, label="Wiki")
210
- chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=450,scale=0.5)
211
- chat_input = gr.Textbox(lines=1, label="Chat Input")
212
- with gr.Row():
213
- clear_button_text = gr.Button(value="Clear Text", interactive=True)
214
- submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
215
  clear_button_clike.click(
216
  lambda x: ([[], [], []], x, ""),
217
  [origin_image],
@@ -220,9 +234,9 @@ with gr.Blocks(
220
  show_progress=False
221
  )
222
  clear_button_image.click(
223
- lambda: (None, [], [], [[], [], []], ""),
224
  [],
225
- [image_input, chatbot, state, click_state, wiki_output],
226
  queue=False,
227
  show_progress=False
228
  )
@@ -234,20 +248,25 @@ with gr.Blocks(
234
  show_progress=False
235
  )
236
  image_input.clear(
237
- lambda: (None, [], [], [[], [], []], ""),
238
  [],
239
- [image_input, chatbot, state, click_state, wiki_output],
240
  queue=False,
241
  show_progress=False
242
  )
243
 
244
- examples = gr.Examples(
245
  examples=examples,
246
- inputs=[image_input],
247
  )
248
 
249
- image_input.upload(upload_callback,[image_input, state], [state, origin_image, click_state])
250
  chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
 
251
 
252
  # select coordinate
253
  image_input.select(inference_seg_cap,
@@ -264,5 +283,5 @@ with gr.Blocks(
264
  outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
265
  show_progress=False, queue=True)
266
 
267
- iface.queue(concurrency_count=5, api_open=False, max_size=10)
268
- iface.launch(server_name="0.0.0.0", enable_queue=True)
 
40
  """
41
 
42
  examples = [
43
+ ["test_img/img35.webp"],
44
  ["test_img/img2.jpg"],
45
  ["test_img/img5.jpg"],
46
  ["test_img/img12.jpg"],
47
  ["test_img/img14.jpg"],
48
+ ["test_img/img0.png"],
49
+ ["test_img/img1.jpg"],
50
  ]
51
 
52
  args = parse_augment()
53
  # args.device = 'cuda:5'
54
  # args.disable_gpt = False
55
  # args.enable_reduce_tokens = True
 
57
  model = CaptionAnything(args)
58
 
59
  def init_openai_api_key(api_key):
60
+ # os.environ['OPENAI_API_KEY'] = api_key
61
+ model.init_refiner(api_key)
62
+ openai_available = model.text_refiner is not None
63
+ return gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = True), gr.update(visible = True)
64
 
65
  def get_prompt(chat_input, click_state):
66
  points = click_state[0]
 
79
  return prompt
80
 
81
  def chat_with_points(chat_input, click_state, state):
82
+ if model.text_refiner is None:
83
  response = "Text refiner is not initilzed, please input openai api key."
84
  state = state + [(chat_input, response)]
85
  return state, state
 
133
  image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
134
 
135
  yield state, state, click_state, chat_input, image_input, wiki
136
+ if not args.disable_gpt and model.text_refiner:
137
  refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
138
  # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
139
  new_cap = refined_caption['caption']
 
144
  def upload_callback(image_input, state):
145
  state = [] + [('Image size: ' + str(image_input.size), None)]
146
  click_state = [[], [], []]
147
+ res = 1024
148
+ width, height = image_input.size
149
+ ratio = min(1.0 * res / max(width, height), 1.0)
150
+ if ratio < 1.0:
151
+ image_input = image_input.resize((int(width * ratio), int(height * ratio)))
152
+ print('Scaling input image to {}'.format(image_input.size))
153
  model.segmenter.image = None
154
  model.segmenter.image_embedding = None
155
  model.segmenter.set_image(image_input)
156
+ return state, image_input, click_state, image_input
157
 
158
  with gr.Blocks(
159
  css='''
 
170
 
171
  with gr.Row():
172
  with gr.Column(scale=1.0):
173
+ with gr.Column(visible=False) as modules_not_need_gpt:
174
+ image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
175
+ example_image = gr.Image(type="pil", interactive=False, visible=False)
176
+ with gr.Row(scale=1.0):
177
+ point_prompt = gr.Radio(
178
+ choices=["Positive", "Negative"],
179
+ value="Positive",
180
+ label="Point Prompt",
181
+ interactive=True)
182
+ clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
183
+ clear_button_image = gr.Button(value="Clear Image", interactive=True)
184
+ with gr.Column(visible=False) as modules_need_gpt:
185
+ with gr.Row(scale=1.0):
186
+ language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
187
+
188
+ sentiment = gr.Radio(
189
+ choices=["Positive", "Natural", "Negative"],
190
+ value="Natural",
191
+ label="Sentiment",
192
+ interactive=True,
193
+ )
194
+ with gr.Row(scale=1.0):
195
+ factuality = gr.Radio(
196
+ choices=["Factual", "Imagination"],
197
+ value="Factual",
198
+ label="Factuality",
199
+ interactive=True,
200
+ )
201
+ length = gr.Slider(
202
+ minimum=10,
203
+ maximum=80,
204
+ value=10,
205
+ step=1,
206
+ interactive=True,
207
+ label="Length",
208
+ )
209
 
210
  with gr.Column(scale=0.5):
211
  openai_api_key = gr.Textbox(
212
+ placeholder="Input openAI API key and press Enter (Input blank will disable GPT)",
213
  show_label=False,
214
  label = "OpenAI API Key",
215
  lines=1,
216
  type="password"
217
  )
218
+ with gr.Column(visible=False) as modules_need_gpt2:
219
+ wiki_output = gr.Textbox(lines=6, label="Wiki")
220
+ with gr.Column(visible=False) as modules_not_need_gpt2:
221
+ chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=450,scale=0.5)
222
+ with gr.Column(visible=False) as modules_need_gpt3:
223
+ chat_input = gr.Textbox(lines=1, label="Chat Input")
224
+ with gr.Row():
225
+ clear_button_text = gr.Button(value="Clear Text", interactive=True)
226
+ submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
227
+
228
+ openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2])
229
  clear_button_clike.click(
230
  lambda x: ([[], [], []], x, ""),
231
  [origin_image],
 
234
  show_progress=False
235
  )
236
  clear_button_image.click(
237
+ lambda: (None, [], [], [[], [], []], "", ""),
238
  [],
239
+ [image_input, chatbot, state, click_state, wiki_output, origin_image],
240
  queue=False,
241
  show_progress=False
242
  )
 
248
  show_progress=False
249
  )
250
  image_input.clear(
251
+ lambda: (None, [], [], [[], [], []], "", ""),
252
  [],
253
+ [image_input, chatbot, state, click_state, wiki_output, origin_image],
254
  queue=False,
255
  show_progress=False
256
  )
257
 
258
+ def example_callback(x):
259
+ model.image_embedding = None
260
+ return x
261
+
262
+ gr.Examples(
263
  examples=examples,
264
+ inputs=[example_image],
265
  )
266
 
267
+ image_input.upload(upload_callback,[image_input, state], [state, origin_image, click_state, image_input])
268
  chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
269
+ example_image.change(upload_callback,[example_image, state], [state, origin_image, click_state, image_input])
270
 
271
  # select coordinate
272
  image_input.select(inference_seg_cap,
 
283
  outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
284
  show_progress=False, queue=True)
285
 
286
+ iface.queue(concurrency_count=1, api_open=False, max_size=10)
287
+ iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
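
To make the visibility wiring above easier to follow: init_openai_api_key returns one gr.update(...) per component named in the submit handler's outputs list, in the same order, instead of writing the key to the environment. Below is a minimal, runnable sketch of that pattern under Gradio 3.x; the component names are illustrative, not the app's.

    import gradio as gr

    def on_key_submit(api_key):
        # stand-in for model.init_refiner(api_key); here any non-empty key "works"
        key_ok = bool(api_key.strip())
        # one gr.update per entry in `outputs=`, in the same order
        return gr.update(visible=key_ok), gr.update(visible=True)

    with gr.Blocks() as demo:
        key_box = gr.Textbox(type="password", show_label=False,
                             placeholder="Input OpenAI API key and press Enter")
        with gr.Column(visible=False) as modules_need_gpt:      # revealed only if the key is usable
            gr.Textbox(lines=6, label="Wiki")
        with gr.Column(visible=False) as modules_not_need_gpt:  # revealed once any key is submitted
            gr.Image(type="pil")
        key_box.submit(on_key_submit, inputs=[key_box],
                       outputs=[modules_need_gpt, modules_not_need_gpt])

    demo.launch()
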
app_old.py CHANGED
@@ -98,9 +98,9 @@ def chat_with_points(chat_input, click_state, state):
98
  return state, state
99
 
100
  def init_openai_api_key(api_key):
101
- os.environ['OPENAI_API_KEY'] = api_key
102
  global model
103
- model = CaptionAnything(args)
104
 
105
  css='''
106
  #image_upload{min-height:200px}
 
98
  return state, state
99
 
100
  def init_openai_api_key(api_key):
101
+ # os.environ['OPENAI_API_KEY'] = api_key
102
  global model
103
+ model = CaptionAnything(args, api_key)
104
 
105
  css='''
106
  #image_upload{min-height:200px}
caption_anything.py CHANGED
@@ -8,18 +8,22 @@ import time
8
  from PIL import Image
9
 
10
  class CaptionAnything():
11
- def __init__(self, args):
12
  self.args = args
13
  self.captioner = build_captioner(args.captioner, args.device, args)
14
  self.segmenter = build_segmenter(args.segmenter, args.device, args)
 
15
  if not args.disable_gpt:
16
- self.init_refiner()
17
 
18
-
19
- def init_refiner(self):
20
- if os.environ.get('OPENAI_API_KEY', None):
21
- self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args)
22
-
23
  def inference(self, image, prompt, controls, disable_gpt=False):
24
  # segment with prompt
25
  print("CA prompt: ", prompt, "CA controls",controls)
@@ -35,14 +39,14 @@ class CaptionAnything():
35
  print("seg_mask.shape: ", seg_mask.shape)
36
  # captioning with mask
37
  if self.args.enable_reduce_tokens:
38
- caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box)
39
  else:
40
- caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box)
41
  # refining with TextRefiner
42
  context_captions = []
43
  if self.args.context_captions:
44
  context_captions.append(self.captioner.inference(image))
45
- if not disable_gpt and hasattr(self, "text_refiner"):
46
  refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
47
  else:
48
  refined_caption = {'raw_caption': caption}
@@ -54,14 +58,14 @@ class CaptionAnything():
54
 
55
  def parse_augment():
56
  parser = argparse.ArgumentParser()
57
- parser.add_argument('--captioner', type=str, default="blip")
58
  parser.add_argument('--segmenter', type=str, default="base")
59
  parser.add_argument('--text_refiner', type=str, default="base")
60
  parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
61
- parser.add_argument('--seg_crop_mode', type=str, default="w_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
62
  parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
63
  parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption (TODO)")
64
- parser.add_argument('--regular_box', action="store_true", default = False, help="crop image with a regular box")
65
  parser.add_argument('--device', type=str, default="cuda:0")
66
  parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
67
  parser.add_argument('--debug', action="store_true")
@@ -101,7 +105,7 @@ if __name__ == "__main__":
101
  "language": "English",
102
  }
103
 
104
- model = CaptionAnything(args)
105
  for prompt in prompts:
106
  print('*'*30)
107
  print('Image path: ', image_path)
 
8
  from PIL import Image
9
 
10
  class CaptionAnything():
11
+ def __init__(self, args, api_key=""):
12
  self.args = args
13
  self.captioner = build_captioner(args.captioner, args.device, args)
14
  self.segmenter = build_segmenter(args.segmenter, args.device, args)
15
+ self.text_refiner = None
16
  if not args.disable_gpt:
17
+ self.init_refiner(api_key)
18
 
19
+ def init_refiner(self, api_key):
20
+ try:
21
+ self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
22
+ self.text_refiner.llm('hi') # test
23
+ except:
24
+ self.text_refiner = None
25
+ print('Openai api key is NOT given')
26
+
27
  def inference(self, image, prompt, controls, disable_gpt=False):
28
  # segment with prompt
29
  print("CA prompt: ", prompt, "CA controls",controls)
 
39
  print("seg_mask.shape: ", seg_mask.shape)
40
  # captioning with mask
41
  if self.args.enable_reduce_tokens:
42
+ caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, disable_regular_box = self.args.disable_regular_box)
43
  else:
44
+ caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, disable_regular_box = self.args.disable_regular_box)
45
  # refining with TextRefiner
46
  context_captions = []
47
  if self.args.context_captions:
48
  context_captions.append(self.captioner.inference(image))
49
+ if not disable_gpt and self.text_refiner is not None:
50
  refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
51
  else:
52
  refined_caption = {'raw_caption': caption}
 
58
 
59
  def parse_augment():
60
  parser = argparse.ArgumentParser()
61
+ parser.add_argument('--captioner', type=str, default="blip2")
62
  parser.add_argument('--segmenter', type=str, default="base")
63
  parser.add_argument('--text_refiner', type=str, default="base")
64
  parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
65
+ parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
66
  parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
67
  parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption (TODO)")
68
+ parser.add_argument('--disable_regular_box', action="store_true", default = False, help="crop image with a regular box")
69
  parser.add_argument('--device', type=str, default="cuda:0")
70
  parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
71
  parser.add_argument('--debug', action="store_true")
 
105
  "language": "English",
106
  }
107
 
108
+ model = CaptionAnything(args, os.environ['OPENAI_API_KEY'])
109
  for prompt in prompts:
110
  print('*'*30)
111
  print('Image path: ', image_path)
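
The probe-and-fallback pattern used by the new init_refiner (issue a throwaway llm('hi') call, keep None on failure, and let callers test `text_refiner is None`) can be exercised without the real refiner. A minimal sketch with a dummy class standing in for build_text_refiner; all names below are illustrative:

    class DummyRefiner:
        """Stand-in for the real text refiner; rejects empty keys."""
        def __init__(self, api_key):
            if not api_key:
                raise ValueError("no API key")
        def llm(self, text):
            return text  # placeholder for an actual LLM call

    def init_refiner(api_key):
        try:
            refiner = DummyRefiner(api_key)
            refiner.llm('hi')  # cheap test call: any exception keeps GPT features disabled
            return refiner
        except Exception:
            print('OpenAI API key is NOT given or invalid')
            return None

    text_refiner = init_refiner("")        # -> None, GPT-dependent paths are skipped
    if text_refiner is not None:
        print(text_refiner.llm("refine this caption"))
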
captioner/base_captioner.py CHANGED
@@ -135,7 +135,7 @@ class BaseCaptioner:
135
  return caption, crop_save_path
136
 
137
 
138
- def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", filter=False, regular_box = False):
139
  if type(image) == str:
140
  image = Image.open(image)
141
  if type(seg_mask) == str:
@@ -151,14 +151,14 @@ class BaseCaptioner:
151
  else:
152
  image = np.array(image)
153
 
154
- if regular_box:
155
- min_area_box = new_seg_to_box(seg_mask)
156
- else:
157
  min_area_box = seg_to_box(seg_mask)
158
  return self.inference_box(image, min_area_box, filter)
159
 
160
 
161
- def generate_seg_cropped_image(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", regular_box = False):
162
  if type(image) == str:
163
  image = Image.open(image)
164
  if type(seg_mask) == str:
@@ -173,10 +173,10 @@ class BaseCaptioner:
173
  else:
174
  image = np.array(image)
175
 
176
- if regular_box:
177
- box = new_seg_to_box(seg_mask)
178
- else:
179
  box = seg_to_box(seg_mask)
180
 
181
  if np.array(box).size == 4: # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
182
  size = max(image.shape[0], image.shape[1])
 
135
  return caption, crop_save_path
136
 
137
 
138
+ def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", filter=False, disable_regular_box = False):
139
  if type(image) == str:
140
  image = Image.open(image)
141
  if type(seg_mask) == str:
 
151
  else:
152
  image = np.array(image)
153
 
154
+ if disable_regular_box:
155
  min_area_box = seg_to_box(seg_mask)
156
+ else:
157
+ min_area_box = new_seg_to_box(seg_mask)
158
  return self.inference_box(image, min_area_box, filter)
159
 
160
 
161
+ def generate_seg_cropped_image(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", disable_regular_box = False):
162
  if type(image) == str:
163
  image = Image.open(image)
164
  if type(seg_mask) == str:
 
173
  else:
174
  image = np.array(image)
175
 
176
+ if disable_regular_box:
177
  box = seg_to_box(seg_mask)
178
+ else:
179
+ box = new_seg_to_box(seg_mask)
180
 
181
  if np.array(box).size == 4: # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
182
  size = max(image.shape[0], image.shape[1])
captioner/blip.py CHANGED
@@ -25,15 +25,15 @@ class BLIPCaptioner(BaseCaptioner):
25
  image = Image.open(image)
26
  inputs = self.processor(image, return_tensors="pt").to(self.device, self.torch_dtype)
27
  out = self.model.generate(**inputs, max_new_tokens=50)
28
- captions = self.processor.decode(out[0], skip_special_tokens=True)
29
  if self.enable_filter and filter:
30
  captions = self.filter_caption(image, captions)
31
  print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
32
  return captions
33
 
34
  @torch.no_grad()
35
- def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False, regular_box = False):
36
- crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode, regular_box=regular_box)
37
  if type(image) == str: # input path
38
  image = Image.open(image)
39
  inputs = self.processor(image, return_tensors="pt")
@@ -45,7 +45,7 @@ class BLIPCaptioner(BaseCaptioner):
45
  seg_mask = seg_mask.float()
46
  pixel_masks = seg_mask.unsqueeze(0).to(self.device)
47
  out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
48
- captions = self.processor.decode(out[0], skip_special_tokens=True)
49
  if self.enable_filter and filter:
50
  captions = self.filter_caption(image, captions)
51
  print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
 
25
  image = Image.open(image)
26
  inputs = self.processor(image, return_tensors="pt").to(self.device, self.torch_dtype)
27
  out = self.model.generate(**inputs, max_new_tokens=50)
28
+ captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
29
  if self.enable_filter and filter:
30
  captions = self.filter_caption(image, captions)
31
  print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
32
  return captions
33
 
34
  @torch.no_grad()
35
+ def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False, disable_regular_box = False):
36
+ crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode, disable_regular_box=disable_regular_box)
37
  if type(image) == str: # input path
38
  image = Image.open(image)
39
  inputs = self.processor(image, return_tensors="pt")
 
45
  seg_mask = seg_mask.float()
46
  pixel_masks = seg_mask.unsqueeze(0).to(self.device)
47
  out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
48
+ captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
49
  if self.enable_filter and filter:
50
  captions = self.filter_caption(image, captions)
51
  print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
captioner/blip2.py CHANGED
@@ -22,9 +22,10 @@ class BLIP2Captioner(BaseCaptioner):
22
  image = Image.open(image)
23
 
24
  if not self.dialogue:
25
- inputs = self.processor(image, text = 'Ignore the black background! This is a photo of ', return_tensors="pt").to(self.device, self.torch_dtype)
 
26
  out = self.model.generate(**inputs, max_new_tokens=50)
27
- captions = self.processor.decode(out[0], skip_special_tokens=True)
28
  if self.enable_filter and filter:
29
  captions = self.filter_caption(image, captions)
30
  print(f"\nProcessed ImageCaptioning by BLIP2Captioner, Output Text: {captions}")
 
22
  image = Image.open(image)
23
 
24
  if not self.dialogue:
25
+ text_prompt = 'Context: ignore the white part in this image. Question: describe this image. Answer:'
26
+ inputs = self.processor(image, text = text_prompt, return_tensors="pt").to(self.device, self.torch_dtype)
27
  out = self.model.generate(**inputs, max_new_tokens=50)
28
+ captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
29
  if self.enable_filter and filter:
30
  captions = self.filter_caption(image, captions)
31
  print(f"\nProcessed ImageCaptioning by BLIP2Captioner, Output Text: {captions}")
captioner/git.py CHANGED
@@ -22,15 +22,15 @@ class GITCaptioner(BaseCaptioner):
22
  image = Image.open(image)
23
  pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, self.torch_dtype)
24
  generated_ids = self.model.generate(pixel_values=pixel_values, max_new_tokens=50)
25
- generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
26
  if self.enable_filter and filter:
27
  captions = self.filter_caption(image, captions)
28
  print(f"\nProcessed ImageCaptioning by GITCaptioner, Output Text: {generated_caption}")
29
  return generated_caption
30
 
31
  @torch.no_grad()
32
- def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False, regular_box = False):
33
- crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode, regular_box=regular_box)
34
  if type(image) == str: # input path
35
  image = Image.open(image)
36
  inputs = self.processor(images=image, return_tensors="pt")
@@ -42,7 +42,7 @@ class GITCaptioner(BaseCaptioner):
42
  seg_mask = seg_mask.float()
43
  pixel_masks = seg_mask.unsqueeze(0).to(self.device)
44
  out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
45
- captions = self.processor.decode(out[0], skip_special_tokens=True)
46
  if self.enable_filter and filter:
47
  captions = self.filter_caption(image, captions)
48
  print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
 
22
  image = Image.open(image)
23
  pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, self.torch_dtype)
24
  generated_ids = self.model.generate(pixel_values=pixel_values, max_new_tokens=50)
25
+ generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
26
  if self.enable_filter and filter:
27
  captions = self.filter_caption(image, captions)
28
  print(f"\nProcessed ImageCaptioning by GITCaptioner, Output Text: {generated_caption}")
29
  return generated_caption
30
 
31
  @torch.no_grad()
32
+ def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False, disable_regular_box = False):
33
+ crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode, disable_regular_box=disable_regular_box)
34
  if type(image) == str: # input path
35
  image = Image.open(image)
36
  inputs = self.processor(images=image, return_tensors="pt")
 
42
  seg_mask = seg_mask.float()
43
  pixel_masks = seg_mask.unsqueeze(0).to(self.device)
44
  out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
45
+ captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
46
  if self.enable_filter and filter:
47
  captions = self.filter_caption(image, captions)
48
  print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
image_editing_utils.py CHANGED
@@ -35,7 +35,8 @@ def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.
35
 
36
  # Wrap the text to fit within the max_text_width
37
  lines = wrap_text(text, font, max_text_width)
38
- text_width, text_height = font.getsize(lines[0])
 
39
  text_height = text_height * len(lines)
40
 
41
  # Define bubble frame dimensions
@@ -48,7 +49,7 @@ def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.
48
 
49
  # Draw the bubble frame on the new image
50
  draw = ImageDraw.Draw(bubble)
51
- draw.rectangle([(0, 0), (bubble_width - 1, bubble_height - 1)], fill=(255, 255, 255, 0), outline=(255, 255, 255, 0), width=2)
52
 
53
  # Draw the wrapped text line by line
54
  y_text = padding
 
35
 
36
  # Wrap the text to fit within the max_text_width
37
  lines = wrap_text(text, font, max_text_width)
38
+ text_width = max([font.getsize(line)[0] for line in lines])
39
+ _, text_height = font.getsize(lines[0])
40
  text_height = text_height * len(lines)
41
 
42
  # Define bubble frame dimensions
 
49
 
50
  # Draw the bubble frame on the new image
51
  draw = ImageDraw.Draw(bubble)
52
+ # draw.rectangle([(0, 0), (bubble_width - 1, bubble_height - 1)], fill=(255, 255, 255, 0), outline=(255, 255, 255, 0), width=2)
53
 
54
  # Draw the wrapped text line by line
55
  y_text = padding
segmenter/base_segmenter.py CHANGED
@@ -46,7 +46,7 @@ class BaseSegmenter:
46
  new_masks = np.concatenate([mask["segmentation"][np.newaxis,:] for mask in masks])
47
  return new_masks
48
  else:
49
- if not self.reuse_feature:
50
  self.set_image(image)
51
  self.predictor.set_image(self.image)
52
  else:
 
46
  new_masks = np.concatenate([mask["segmentation"][np.newaxis,:] for mask in masks])
47
  return new_masks
48
  else:
49
+ if not self.reuse_feature or self.image_embedding is None:
50
  self.set_image(image)
51
  self.predictor.set_image(self.image)
52
  else:
text_refiner/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
  from text_refiner.text_refiner import TextRefiner
2
 
3
 
4
- def build_text_refiner(type, device, args=None):
5
  if type == 'base':
6
- return TextRefiner(device)
 
1
  from text_refiner.text_refiner import TextRefiner
2
 
3
 
4
+ def build_text_refiner(type, device, args=None, api_key=""):
5
  if type == 'base':
6
+ return TextRefiner(device, api_key)
text_refiner/text_refiner.py CHANGED
@@ -5,12 +5,9 @@ from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration,
5
  import pdb
6
 
7
  class TextRefiner:
8
- def __init__(self, device):
9
  print(f"Initializing TextRefiner to {device}")
10
- try:
11
- self.llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0)
12
- except:
13
- print('Openai api key is NOT given')
14
  self.prompt_tag = {
15
  "imagination": {"True": "could",
16
  "False": "could not"}
 
5
  import pdb
6
 
7
  class TextRefiner:
8
+ def __init__(self, device, api_key=""):
9
  print(f"Initializing TextRefiner to {device}")
10
+ self.llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
 
  self.prompt_tag = {
12
  "imagination": {"True": "could",
13
  "False": "could not"}
tools.py CHANGED
@@ -1,7 +1,9 @@
1
  import cv2
 
2
  import numpy as np
3
  from PIL import Image
4
  import copy
 
5
 
6
 
7
  def colormap(rgb=True):
@@ -100,16 +102,6 @@ color_list = colormap()
100
  color_list = color_list.astype('uint8').tolist()
101
 
102
 
103
- def gauss_filter(kernel_size, sigma):
104
- max_idx = kernel_size // 2
105
- idx = np.linspace(-max_idx, max_idx, kernel_size)
106
- Y, X = np.meshgrid(idx, idx)
107
- gauss_filter = np.exp(-(X**2 + Y**2) / (2*sigma**2))
108
- gauss_filter /= np.sum(np.sum(gauss_filter))
109
-
110
- return gauss_filter
111
-
112
-
113
  def vis_add_mask(image, mask, color, alpha, kernel_size):
114
  color = np.array(color)
115
  mask = mask.astype('float').copy()
@@ -129,6 +121,23 @@ def vis_add_mask_wo_blur(image, mask, color, alpha):
129
  return image
130
 
131
 
132
  def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1):
133
  """
134
  Input:
@@ -146,11 +155,7 @@ def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_
146
  assert input_image.shape[:2] == input_mask.shape, 'different shape'
147
  assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
148
 
149
- width, height = input_image.shape[0], input_image.shape[1]
150
- res = 1024
151
- ratio = min(1.0 * res / max(width, height), 1.0)
152
- input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
153
- input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
154
  # 0: background, 1: foreground
155
  input_mask[input_mask>0] = 255
156
 
@@ -163,15 +168,120 @@ def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_
163
  kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
164
  contour_mask = cv2.dilate(contour_mask, kernel)
165
  painted_image = vis_add_mask(painted_image, 255-contour_mask, color_list[contour_color], contour_alpha, contour_width)
166
- painted_image = cv2.resize(painted_image, (height, width))
167
  return painted_image
168
 
169
 
170
  if __name__ == '__main__':
171
 
172
  background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
173
- background_blur_radius = 35 # radius of background blur, must be odd number
174
- contour_width = 7 # contour width, must be odd number
175
  contour_color = 3 # id in color map, 0: black, 1: white, >1: others
176
  contour_alpha = 1 # transparency of background, 0: no contour highlighted
177
 
@@ -180,8 +290,54 @@ if __name__ == '__main__':
180
  input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
181
 
182
  # paint
183
- painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
184
 
185
  # save
186
- painted_image = Image.fromarray(painted_image)
187
- painted_image.save('./test_img/painter_output_image.png')
1
  import cv2
2
+ import torch
3
  import numpy as np
4
  from PIL import Image
5
  import copy
6
+ import time
7
 
8
 
9
  def colormap(rgb=True):
 
102
  color_list = color_list.astype('uint8').tolist()
103
 
104
 
105
  def vis_add_mask(image, mask, color, alpha, kernel_size):
106
  color = np.array(color)
107
  mask = mask.astype('float').copy()
 
121
  return image
122
 
123
 
124
+ def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
125
+ background_color = np.array(background_color)
126
+ contour_color = np.array(contour_color)
127
+
128
+ # background_mask = 1 - background_mask
129
+ # contour_mask = 1 - contour_mask
130
+
131
+ for i in range(3):
132
+ image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
133
+ + background_color[i] * (background_alpha-background_mask*background_alpha)
134
+
135
+ image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
136
+ + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
137
+
138
+ return image.astype('uint8')
139
+
140
+
141
  def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1):
142
  """
143
  Input:
 
155
  assert input_image.shape[:2] == input_mask.shape, 'different shape'
156
  assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
157
 
158
+
159
  # 0: background, 1: foreground
160
  input_mask[input_mask>0] = 255
161
 
 
168
  kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
169
  contour_mask = cv2.dilate(contour_mask, kernel)
170
  painted_image = vis_add_mask(painted_image, 255-contour_mask, color_list[contour_color], contour_alpha, contour_width)
171
+
172
+ # painted_image = background_dist_map
173
+
174
+ return painted_image
175
+
176
+
177
+ def mask_generator_00(mask, background_radius, contour_radius):
178
+ # no background width when '00'
179
+ # distance map
180
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
181
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
182
+ dist_map = dist_transform_fore - dist_transform_back
183
+ # ...:::!!!:::...
184
+ contour_radius += 2
185
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
186
+ contour_mask = contour_mask / np.max(contour_mask)
187
+ contour_mask[contour_mask>0.5] = 1.
188
+
189
+ return mask, contour_mask
190
+
191
+
192
+ def mask_generator_01(mask, background_radius, contour_radius):
193
+ # no background width when '00'
194
+ # distance map
195
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
196
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
197
+ dist_map = dist_transform_fore - dist_transform_back
198
+ # ...:::!!!:::...
199
+ contour_radius += 2
200
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
201
+ contour_mask = contour_mask / np.max(contour_mask)
202
+ return mask, contour_mask
203
+
204
+
205
+ def mask_generator_10(mask, background_radius, contour_radius):
206
+ # distance map
207
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
208
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
209
+ dist_map = dist_transform_fore - dist_transform_back
210
+ # .....:::::!!!!!
211
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
212
+ background_mask = (background_mask - np.min(background_mask))
213
+ background_mask = background_mask / np.max(background_mask)
214
+ # ...:::!!!:::...
215
+ contour_radius += 2
216
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
217
+ contour_mask = contour_mask / np.max(contour_mask)
218
+ contour_mask[contour_mask>0.5] = 1.
219
+ return background_mask, contour_mask
220
+
221
+
222
+ def mask_generator_11(mask, background_radius, contour_radius):
223
+ # distance map
224
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
225
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
226
+ dist_map = dist_transform_fore - dist_transform_back
227
+ # .....:::::!!!!!
228
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
229
+ background_mask = (background_mask - np.min(background_mask))
230
+ background_mask = background_mask / np.max(background_mask)
231
+ # ...:::!!!:::...
232
+ contour_radius += 2
233
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
234
+ contour_mask = contour_mask / np.max(contour_mask)
235
+ return background_mask, contour_mask
236
+
237
+
238
+ def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
239
+ """
240
+ Input:
241
+ input_image: numpy array
242
+ input_mask: numpy array
243
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
244
+ background_blur_radius: radius of background blur, must be odd number
245
+ contour_width: width of mask contour, must be odd number
246
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
247
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
248
+ mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
249
+
250
+ Output:
251
+ painted_image: numpy array
252
+ """
253
+ assert input_image.shape[:2] == input_mask.shape, 'different shape'
254
+ assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
255
+ assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
256
+
257
+ # downsample input image and mask
258
+ width, height = input_image.shape[0], input_image.shape[1]
259
+ res = 1024
260
+ ratio = min(1.0 * res / max(width, height), 1.0)
261
+ input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
262
+ input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
263
+
264
+ # 0: background, 1: foreground
265
+ msk = np.clip(input_mask, 0, 1)
266
+
267
+ # generate masks for background and contour pixels
268
+ background_radius = (background_blur_radius - 1) // 2
269
+ contour_radius = (contour_width - 1) // 2
270
+ generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
271
+ background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
272
+
273
+ # paint
274
+ painted_image = vis_add_mask_wo_gaussian\
275
+ (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
276
+
277
  return painted_image
278
 
279
 
280
  if __name__ == '__main__':
281
 
282
  background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
283
+ background_blur_radius = 31 # radius of background blur, must be odd number
284
+ contour_width = 11 # contour width, must be odd number
285
  contour_color = 3 # id in color map, 0: black, 1: white, >1: others
286
  contour_alpha = 1 # transparency of background, 0: no contour highlighted
287
 
 
290
  input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
291
 
292
  # paint
293
+ overall_time_1 = 0
294
+ overall_time_2 = 0
295
+ overall_time_3 = 0
296
+ overall_time_4 = 0
297
+ overall_time_5 = 0
298
+
299
+ for i in range(50):
300
+ t2 = time.time()
301
+ painted_image_00 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
302
+ e2 = time.time()
303
+
304
+ t3 = time.time()
305
+ painted_image_10 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
306
+ e3 = time.time()
307
+
308
+ t1 = time.time()
309
+ painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
310
+ e1 = time.time()
311
+
312
+ t4 = time.time()
313
+ painted_image_01 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
314
+ e4 = time.time()
315
+
316
+ t5 = time.time()
317
+ painted_image_11 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
318
+ e5 = time.time()
319
+
320
+ overall_time_1 += (e1 - t1)
321
+ overall_time_2 += (e2 - t2)
322
+ overall_time_3 += (e3 - t3)
323
+ overall_time_4 += (e4 - t4)
324
+ overall_time_5 += (e5 - t5)
325
+
326
+ print(f'average time w gaussian: {overall_time_1/50}')
327
+ print(f'average time w/o gaussian00: {overall_time_2/50}')
328
+ print(f'average time w/o gaussian10: {overall_time_3/50}')
329
+ print(f'average time w/o gaussian01: {overall_time_4/50}')
330
+ print(f'average time w/o gaussian11: {overall_time_5/50}')
331
 
332
  # save
333
+ painted_image_00 = Image.fromarray(painted_image_00)
334
+ painted_image_00.save('./test_img/painter_output_image_00.png')
335
+
336
+ painted_image_10 = Image.fromarray(painted_image_10)
337
+ painted_image_10.save('./test_img/painter_output_image_10.png')
338
+
339
+ painted_image_01 = Image.fromarray(painted_image_01)
340
+ painted_image_01.save('./test_img/painter_output_image_01.png')
341
+
342
+ painted_image_11 = Image.fromarray(painted_image_11)
343
+ painted_image_11.save('./test_img/painter_output_image_11.png')
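
The new mask_generator_* helpers above share one idea: cv2.distanceTransform on the mask and on its complement gives a signed distance map, and clipping its absolute value at the contour radius produces a soft band around the object boundary. A tiny self-contained check of that idea, with the radius and the 0.5 threshold mirroring the values used above:

    import cv2
    import numpy as np

    mask = np.zeros((64, 64), dtype=np.uint8)
    mask[16:48, 16:48] = 1                      # a square foreground region

    dist_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
    dist_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
    dist_map = dist_fore - dist_back            # signed distance: >0 inside, <0 outside

    contour_radius = (11 - 1) // 2 + 2          # contour_width=11 as in the __main__ block
    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
    contour_mask = contour_mask / np.max(contour_mask)
    contour_mask[contour_mask > 0.5] = 1.0      # 1 away from the edge, <1 in a thin band around it

    print(contour_mask.min(), contour_mask.max())   # small values only near the square's border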