liuhaotian commited on
Commit
d933b45
β€’
1 Parent(s): e6da15b
Files changed (1) hide show
  1. app.py +17 -7
app.py CHANGED
@@ -14,6 +14,9 @@ import gc
14
  from gradio import processing_utils
15
  from typing import Optional
16
 
 
 
 
17
  from huggingface_hub import hf_hub_download
18
  hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")
19
 
@@ -56,10 +59,13 @@ class Instance:
56
  )
57
  self.capacity = capacity
58
 
59
- def get_model(self, model_type):
 
 
 
60
  if model_type in self.loaded_model_list:
61
  self.counter[model_type] += 1
62
- print(self.counter)
63
  return self.loaded_model_list[model_type]
64
 
65
  if self.capacity == len(self.loaded_model_list):
@@ -71,7 +77,7 @@ class Instance:
71
 
72
  self.counter[model_type] = 1
73
  self.loaded_model_list[model_type] = self._get_model(model_type)
74
- print(self.counter)
75
  return self.loaded_model_list[model_type]
76
 
77
  def _get_model(self, model_type):
@@ -218,16 +224,21 @@ def inference(task, language_instruction, grounding_instruction, inpainting_boxe
218
  inpainting_boxes_nodrop = inpainting_boxes_nodrop,
219
  )
220
 
 
 
 
 
 
221
  with torch.autocast(device_type='cuda', dtype=torch.float16):
222
  if task == 'Grounded Generation':
223
  if style_image == None:
224
- return grounded_generation_box(instance.get_model('base'), instruction, *args, **kwargs)
225
  else:
226
- return grounded_generation_box(instance.get_model('style'), instruction, *args, **kwargs)
227
  elif task == 'Grounded Inpainting':
228
  assert image is not None
229
  instruction['input_image'] = image.convert("RGB")
230
- return grounded_generation_box(instance.get_model('inpaint'), instruction, *args, **kwargs)
231
 
232
 
233
  def draw_box(boxes=[], texts=[], img=None):
@@ -264,7 +275,6 @@ def auto_append_grounding(language_instruction, grounding_texts):
264
  for grounding_text in grounding_texts:
265
  if grounding_text not in language_instruction and grounding_text != 'auto':
266
  language_instruction += "; " + grounding_text
267
- print(language_instruction)
268
  return language_instruction
269
 
270
 
 
14
  from gradio import processing_utils
15
  from typing import Optional
16
 
17
+ import warnings
18
+ warnings.filterwarnings("ignore")
19
+
20
  from huggingface_hub import hf_hub_download
21
  hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")
22
 
 
59
  )
60
  self.capacity = capacity
61
 
62
+ def _log(self, batch_size, instruction, phrase_list):
63
+ print(dict(self.counter), f'samples: {batch_size}', f'prompt: {instruction}', f'phrases: {phrase_list}', sep=', ')
64
+
65
+ def get_model(self, model_type, batch_size, instruction, phrase_list):
66
  if model_type in self.loaded_model_list:
67
  self.counter[model_type] += 1
68
+ self._log(batch_size, instruction, phrase_list)
69
  return self.loaded_model_list[model_type]
70
 
71
  if self.capacity == len(self.loaded_model_list):
 
77
 
78
  self.counter[model_type] = 1
79
  self.loaded_model_list[model_type] = self._get_model(model_type)
80
+ self._log(batch_size, instruction, phrase_list)
81
  return self.loaded_model_list[model_type]
82
 
83
  def _get_model(self, model_type):
 
224
  inpainting_boxes_nodrop = inpainting_boxes_nodrop,
225
  )
226
 
227
+ get_model = partial(instance.get_model,
228
+ batch_size=batch_size,
229
+ instruction=language_instruction,
230
+ phrase_list=phrase_list)
231
+
232
  with torch.autocast(device_type='cuda', dtype=torch.float16):
233
  if task == 'Grounded Generation':
234
  if style_image == None:
235
+ return grounded_generation_box(get_model('base'), instruction, *args, **kwargs)
236
  else:
237
+ return grounded_generation_box(get_model('style'), instruction, *args, **kwargs)
238
  elif task == 'Grounded Inpainting':
239
  assert image is not None
240
  instruction['input_image'] = image.convert("RGB")
241
+ return grounded_generation_box(get_model('inpaint'), instruction, *args, **kwargs)
242
 
243
 
244
  def draw_box(boxes=[], texts=[], img=None):
 
275
  for grounding_text in grounding_texts:
276
  if grounding_text not in language_instruction and grounding_text != 'auto':
277
  language_instruction += "; " + grounding_text
 
278
  return language_instruction
279
 
280