chaojiemao commited on
Commit
a2decc8
·
verified ·
1 Parent(s): 868bbb0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -13
app.py CHANGED
@@ -25,8 +25,7 @@ import torch
25
  import transformers
26
  from PIL import Image
27
  from transformers import AutoModel, AutoTokenizer
28
- import model
29
- from ace_inference import ACEInference
30
  from scepter.modules.utils.config import Config
31
  from scepter.modules.utils.directory import get_md5
32
  from scepter.modules.utils.file_system import FS
@@ -49,6 +48,9 @@ chat_sty = '\U0001F4AC' # 💬
49
  video_sty = '\U0001f3a5' # 🎥
50
 
51
  lock = threading.Lock()
 
 
 
52
 
53
 
54
  class ChatBotUI(object):
@@ -94,9 +96,10 @@ class ChatBotUI(object):
94
  assert len(self.model_choices) > 0
95
  if self.default_model_name == "": self.default_model_name = list(self.model_choices.keys())[0]
96
  self.model_name = self.default_model_name
97
- self.pipe = ACEInference()
98
- self.pipe.init_from_cfg(self.model_choices[self.default_model_name])
99
- subprocess.run(shlex.split(f'rm -rf {local_folder}'))
 
100
  self.max_msgs = 20
101
  self.enable_i2v = cfg.get('ENABLE_I2V', False)
102
  self.gradio_version = version('gradio')
@@ -540,8 +543,11 @@ class ChatBotUI(object):
540
  lock.acquire()
541
  del self.pipe
542
  torch.cuda.empty_cache()
543
- self.pipe = ACEInference()
544
- self.pipe.init_from_cfg(self.model_choices[model_name])
 
 
 
545
  self.model_name = model_name
546
  lock.release()
547
 
@@ -829,7 +835,8 @@ class ChatBotUI(object):
829
  edit_image = None
830
  edit_image_mask = None
831
  edit_task = ''
832
-
 
833
  print(new_message)
834
  imgs = self.pipe(
835
  image=edit_image,
@@ -896,9 +903,9 @@ class ChatBotUI(object):
896
  }
897
 
898
  buffered = io.BytesIO()
899
- img.convert('RGB').save(buffered, format='PNG')
900
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
901
- img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
902
 
903
  history.append(
904
  (message,
@@ -1048,17 +1055,17 @@ class ChatBotUI(object):
1048
 
1049
  img = imgs[0]
1050
  buffered = io.BytesIO()
1051
- img.convert('RGB').save(buffered, format='PNG')
1052
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1053
  img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
1054
  history = [(prompt,
1055
  f'{pre_info} The generated image is:\n {img_str}')]
1056
 
1057
  img_id = get_md5(img_b64)[:12]
1058
- save_path = os.path.join(self.cache_dir, f'{img_id}.png')
1059
  img.convert('RGB').save(save_path)
1060
 
1061
- return self.get_history(history), gr.update(value=''), gr.update(
1062
  visible=False), gr.update(value=save_path), gr.update(value=-1)
1063
 
1064
  with self.eg:
 
25
  import transformers
26
  from PIL import Image
27
  from transformers import AutoModel, AutoTokenizer
28
+ from ace_flux_inference import FluxACEInference
 
29
  from scepter.modules.utils.config import Config
30
  from scepter.modules.utils.directory import get_md5
31
  from scepter.modules.utils.file_system import FS
 
48
  video_sty = '\U0001f3a5' # 🎥
49
 
50
  lock = threading.Lock()
51
+ inference_dict = {
52
+ "ACE_FLUX": FluxACEInference,
53
+ }
54
 
55
 
56
  class ChatBotUI(object):
 
96
  assert len(self.model_choices) > 0
97
  if self.default_model_name == "": self.default_model_name = list(self.model_choices.keys())[0]
98
  self.model_name = self.default_model_name
99
+ pipe_cfg = self.model_choices[self.default_model_name]
100
+ infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE")
101
+ self.pipe = inference_dict[infer_name]()
102
+ self.pipe.init_from_cfg(pipe_cfg)
103
  self.max_msgs = 20
104
  self.enable_i2v = cfg.get('ENABLE_I2V', False)
105
  self.gradio_version = version('gradio')
 
543
  lock.acquire()
544
  del self.pipe
545
  torch.cuda.empty_cache()
546
+ torch.cuda.ipc_collect()
547
+ pipe_cfg = self.model_choices[model_name]
548
+ infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE")
549
+ self.pipe = inference_dict[infer_name]()
550
+ self.pipe.init_from_cfg(pipe_cfg)
551
  self.model_name = model_name
552
  lock.release()
553
 
 
835
  edit_image = None
836
  edit_image_mask = None
837
  edit_task = ''
838
+ if new_message == "":
839
+ new_message = "a beautiful girl wear a skirt."
840
  print(new_message)
841
  imgs = self.pipe(
842
  image=edit_image,
 
903
  }
904
 
905
  buffered = io.BytesIO()
906
+ img.convert('RGB').save(buffered, format='JPEG')
907
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
908
+ img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
909
 
910
  history.append(
911
  (message,
 
1055
 
1056
  img = imgs[0]
1057
  buffered = io.BytesIO()
1058
+ img.convert('RGB').save(buffered, format='JPEG')
1059
  img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1060
  img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
1061
  history = [(prompt,
1062
  f'{pre_info} The generated image is:\n {img_str}')]
1063
 
1064
  img_id = get_md5(img_b64)[:12]
1065
+ save_path = os.path.join(self.cache_dir, f'{img_id}.jpg')
1066
  img.convert('RGB').save(save_path)
1067
 
1068
+ return self.get_history(history), gr.update(value=prompt), gr.update(
1069
  visible=False), gr.update(value=save_path), gr.update(value=-1)
1070
 
1071
  with self.eg: