Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -25,8 +25,7 @@ import torch
|
|
25 |
import transformers
|
26 |
from PIL import Image
|
27 |
from transformers import AutoModel, AutoTokenizer
|
28 |
-
import
|
29 |
-
from ace_inference import ACEInference
|
30 |
from scepter.modules.utils.config import Config
|
31 |
from scepter.modules.utils.directory import get_md5
|
32 |
from scepter.modules.utils.file_system import FS
|
@@ -49,6 +48,9 @@ chat_sty = '\U0001F4AC' # 💬
|
|
49 |
video_sty = '\U0001f3a5' # 🎥
|
50 |
|
51 |
lock = threading.Lock()
|
|
|
|
|
|
|
52 |
|
53 |
|
54 |
class ChatBotUI(object):
|
@@ -94,9 +96,10 @@ class ChatBotUI(object):
|
|
94 |
assert len(self.model_choices) > 0
|
95 |
if self.default_model_name == "": self.default_model_name = list(self.model_choices.keys())[0]
|
96 |
self.model_name = self.default_model_name
|
97 |
-
|
98 |
-
|
99 |
-
|
|
|
100 |
self.max_msgs = 20
|
101 |
self.enable_i2v = cfg.get('ENABLE_I2V', False)
|
102 |
self.gradio_version = version('gradio')
|
@@ -540,8 +543,11 @@ class ChatBotUI(object):
|
|
540 |
lock.acquire()
|
541 |
del self.pipe
|
542 |
torch.cuda.empty_cache()
|
543 |
-
|
544 |
-
self.
|
|
|
|
|
|
|
545 |
self.model_name = model_name
|
546 |
lock.release()
|
547 |
|
@@ -829,7 +835,8 @@ class ChatBotUI(object):
|
|
829 |
edit_image = None
|
830 |
edit_image_mask = None
|
831 |
edit_task = ''
|
832 |
-
|
|
|
833 |
print(new_message)
|
834 |
imgs = self.pipe(
|
835 |
image=edit_image,
|
@@ -896,9 +903,9 @@ class ChatBotUI(object):
|
|
896 |
}
|
897 |
|
898 |
buffered = io.BytesIO()
|
899 |
-
img.convert('RGB').save(buffered, format='
|
900 |
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
901 |
-
img_str = f'<img src="data:image/
|
902 |
|
903 |
history.append(
|
904 |
(message,
|
@@ -1048,17 +1055,17 @@ class ChatBotUI(object):
|
|
1048 |
|
1049 |
img = imgs[0]
|
1050 |
buffered = io.BytesIO()
|
1051 |
-
img.convert('RGB').save(buffered, format='
|
1052 |
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1053 |
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
1054 |
history = [(prompt,
|
1055 |
f'{pre_info} The generated image is:\n {img_str}')]
|
1056 |
|
1057 |
img_id = get_md5(img_b64)[:12]
|
1058 |
-
save_path = os.path.join(self.cache_dir, f'{img_id}.
|
1059 |
img.convert('RGB').save(save_path)
|
1060 |
|
1061 |
-
return self.get_history(history), gr.update(value=
|
1062 |
visible=False), gr.update(value=save_path), gr.update(value=-1)
|
1063 |
|
1064 |
with self.eg:
|
|
|
25 |
import transformers
|
26 |
from PIL import Image
|
27 |
from transformers import AutoModel, AutoTokenizer
|
28 |
+
from ace_flux_inference import FluxACEInference
|
|
|
29 |
from scepter.modules.utils.config import Config
|
30 |
from scepter.modules.utils.directory import get_md5
|
31 |
from scepter.modules.utils.file_system import FS
|
|
|
48 |
video_sty = '\U0001f3a5' # 🎥
|
49 |
|
50 |
lock = threading.Lock()
|
51 |
+
inference_dict = {
|
52 |
+
"ACE_FLUX": FluxACEInference,
|
53 |
+
}
|
54 |
|
55 |
|
56 |
class ChatBotUI(object):
|
|
|
96 |
assert len(self.model_choices) > 0
|
97 |
if self.default_model_name == "": self.default_model_name = list(self.model_choices.keys())[0]
|
98 |
self.model_name = self.default_model_name
|
99 |
+
pipe_cfg = self.model_choices[self.default_model_name]
|
100 |
+
infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE")
|
101 |
+
self.pipe = inference_dict[infer_name]()
|
102 |
+
self.pipe.init_from_cfg(pipe_cfg)
|
103 |
self.max_msgs = 20
|
104 |
self.enable_i2v = cfg.get('ENABLE_I2V', False)
|
105 |
self.gradio_version = version('gradio')
|
|
|
543 |
lock.acquire()
|
544 |
del self.pipe
|
545 |
torch.cuda.empty_cache()
|
546 |
+
torch.cuda.ipc_collect()
|
547 |
+
pipe_cfg = self.model_choices[model_name]
|
548 |
+
infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE")
|
549 |
+
self.pipe = inference_dict[infer_name]()
|
550 |
+
self.pipe.init_from_cfg(pipe_cfg)
|
551 |
self.model_name = model_name
|
552 |
lock.release()
|
553 |
|
|
|
835 |
edit_image = None
|
836 |
edit_image_mask = None
|
837 |
edit_task = ''
|
838 |
+
if new_message == "":
|
839 |
+
new_message = "a beautiful girl wear a skirt."
|
840 |
print(new_message)
|
841 |
imgs = self.pipe(
|
842 |
image=edit_image,
|
|
|
903 |
}
|
904 |
|
905 |
buffered = io.BytesIO()
|
906 |
+
img.convert('RGB').save(buffered, format='JPEG')
|
907 |
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
908 |
+
img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
|
909 |
|
910 |
history.append(
|
911 |
(message,
|
|
|
1055 |
|
1056 |
img = imgs[0]
|
1057 |
buffered = io.BytesIO()
|
1058 |
+
img.convert('RGB').save(buffered, format='JPEG')
|
1059 |
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1060 |
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
1061 |
history = [(prompt,
|
1062 |
f'{pre_info} The generated image is:\n {img_str}')]
|
1063 |
|
1064 |
img_id = get_md5(img_b64)[:12]
|
1065 |
+
save_path = os.path.join(self.cache_dir, f'{img_id}.jpg')
|
1066 |
img.convert('RGB').save(save_path)
|
1067 |
|
1068 |
+
return self.get_history(history), gr.update(value=prompt), gr.update(
|
1069 |
visible=False), gr.update(value=save_path), gr.update(value=-1)
|
1070 |
|
1071 |
with self.eg:
|