Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,5 +1,6 @@
-import os
+# import os
 import spaces
+
 import time
 import gradio as gr
 import torch
@@ -8,9 +9,11 @@ from torchvision import transforms
 from dataclasses import dataclass
 import math
 from typing import Callable
+
 from tqdm import tqdm
 import bitsandbytes as bnb
 from bitsandbytes.nn.modules import Params4bit, QuantState
+
 import torch
 import random
 from einops import rearrange, repeat
@@ -18,8 +21,11 @@ from diffusers import AutoencoderKL
 from torch import Tensor, nn
 from transformers import CLIPTextModel, CLIPTokenizer
 from transformers import T5EncoderModel, T5Tokenizer
-from
-from
+# from optimum.quanto import freeze, qfloat8, quantize
+from transformers import pipeline
+
+ko_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
+ja_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ja-en")
 
 class HFEmbedder(nn.Module):
     def __init__(self, version: str, max_length: int, **hf_kwargs):
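For readers unfamiliar with the two translators added above: a transformers translation pipeline returns a list of dicts with a translation_text key. A minimal sketch of the call pattern, assuming transformers and sentencepiece are installed (the model downloads on first use):

from transformers import pipeline

# Korean-to-English Marian model, loaded once and reused for every prompt.
ko_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

result = ko_translator("벚꽃이 흩날리는 서울의 봄 풍경", max_length=512)
print(result[0]["translation_text"])  # an English rendering of the Korean prompt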
@@ -54,6 +60,7 @@ class HFEmbedder(nn.Module):
             output_hidden_states=False,
         )
         return outputs[self.output_key]
+
 
 device = "cuda"
 t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16).to(device)
@@ -63,6 +70,9 @@ ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="va
 # freeze(t5)
 
 
+# ---------------- NF4 ----------------
+
+
 def functional_linear_4bits(x, weight, bias):
     out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
     out = out.to(x)
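For context on the NF4 section introduced above: functional_linear_4bits follows the same call pattern as bitsandbytes' own 4-bit linear forward. A rough sketch of how a weight ends up as a Params4bit with a quant_state; this is an assumption-heavy illustration (it presumes a CUDA device and a recent bitsandbytes) and not the loader this Space actually uses:

import torch
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit

# Wrap an fp16 weight as NF4; moving it to CUDA performs the actual 4-bit quantization.
w = torch.randn(64, 128, dtype=torch.float16)
w4 = Params4bit(w, requires_grad=False, quant_type="nf4").cuda()

x = torch.randn(2, 128, dtype=torch.float16, device="cuda")
y = bnb.matmul_4bit(x, w4.t(), bias=None, quant_state=w4.quant_state)
print(y.shape)  # expected: torch.Size([2, 64])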
@@ -200,6 +210,9 @@ class Linear(ForgeLoader4Bit):
 nn.Linear = Linear
 
 
+# ---------------- Model ----------------
+
+
 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
     q, k = apply_rope(q, k, pe)
 
@@ -724,6 +737,8 @@ def get_image(image) -> torch.Tensor | None:
     return img[None, ...]
 
 
+# ---------------- Demo ----------------
+
 
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
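The imports shown in context here feed the checkpoint-loading lines a few hunks below (sd is what gets passed to model.load_state_dict). A minimal sketch of that pattern, with a placeholder repo_id and filename standing in for whatever the Space actually downloads:

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Placeholder repo/filename; the real values live in the unchanged part of app.py.
ckpt_path = hf_hub_download(repo_id="some-org/some-flux-checkpoint", filename="model.safetensors")
sd = load_file(ckpt_path)  # dict of tensor name -> torch.Tensor
# model.load_state_dict(sd) is then called on the Flux module.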
@@ -734,157 +749,43 @@ model = Flux().to(dtype=torch.bfloat16, device="cuda")
 result = model.load_state_dict(sd)
 model_zero_init = False
 
-
-#
-TRANSLATORS = {
-
-    "Korean": "Helsinki-NLP/opus-mt-ko-en",
-    "Japanese": "Helsinki-NLP/opus-mt-ja-en",
-    "Chinese": "Helsinki-NLP/opus-mt-zh-en",
-    "Russian": "Helsinki-NLP/opus-mt-ru-en",
-    "Spanish": "Helsinki-NLP/opus-mt-es-en",
-    "French": "Helsinki-NLP/opus-mt-fr-en",
-    "Arabic": "Helsinki-NLP/opus-mt-ar-en",
-    "Bengali": "Helsinki-NLP/opus-mt-bn-en",
-    "Estonian": "Helsinki-NLP/opus-mt-et-en",
-    "Polish": "Helsinki-NLP/opus-mt-pl-en",
-    "Swedish": "Helsinki-NLP/opus-mt-sv-en",
-    "Thai": "Helsinki-NLP/opus-mt-th-en",
-    "Urdu": "Helsinki-NLP/opus-mt-ur-en",
-    "Bulgarian": "Helsinki-NLP/opus-mt-bg-en",
-    "Catalan": "Helsinki-NLP/opus-mt-ca-en",
-    "Czech": "Helsinki-NLP/opus-mt-cs-en",
-    "Azerbaijani": "Helsinki-NLP/opus-mt-az-en",
-    "Basque": "Helsinki-NLP/opus-mt-bat-en",
-    "Bicolano": "Helsinki-NLP/opus-mt-bcl-en",
-    "Bemba": "Helsinki-NLP/opus-mt-bem-en",
-    "Berber": "Helsinki-NLP/opus-mt-ber-en",
-    "Bislama": "Helsinki-NLP/opus-mt-bi-en",
-    "Bantu": "Helsinki-NLP/opus-mt-bnt-en",
-    "Brazilian Sign Language": "Helsinki-NLP/opus-mt-bzs-en",
-    "Caucasian": "Helsinki-NLP/opus-mt-cau-en",
-    "Cebuano": "Helsinki-NLP/opus-mt-ceb-en",
-    "Celtic": "Helsinki-NLP/opus-mt-cel-en",
-    "Chuukese": "Helsinki-NLP/opus-mt-chk-en",
-    "Creoles and pidgins (French)": "Helsinki-NLP/opus-mt-cpf-en",
-    "Seychelles Creole": "Helsinki-NLP/opus-mt-crs-en",
-    "American Sign Language": "Helsinki-NLP/opus-mt-ase-en",
-    "Artificial Language": "Helsinki-NLP/opus-mt-art-en",
-    "Atlantic-Congo": "Helsinki-NLP/opus-mt-alv-en",
-    "Afroasiatic": "Helsinki-NLP/opus-mt-afa-en",
-    "Afrikaans": "Helsinki-NLP/opus-mt-af-en",
-    "Austroasiatic": "Helsinki-NLP/opus-mt-aav-en"
-}
-
-translators_cache = {}
-
-
-
-# Set the model cache directory
-os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
-
-def download_model(model_name):
-    """Download the model in advance."""
-    try:
-        cache_dir = os.path.join('/tmp/transformers_cache', model_name.split('/')[-1])
-        snapshot_download(
-            repo_id=model_name,
-            cache_dir=cache_dir,
-            local_files_only=False
-        )
-        return cache_dir
-    except Exception as e:
-        print(f"Error downloading model {model_name}: {e}")
-        return None
-
-def get_translator(lang):
-    """Initialize and return a translator."""
-    if lang == "English":
-        return None
-
-    if lang not in translators_cache:
-        try:
-            model_name = TRANSLATORS[lang]
-
-            # Load the model directly instead of using a pipeline
-            tokenizer = MarianTokenizer.from_pretrained(model_name)
-            model = MarianMTModel.from_pretrained(model_name)
-
-            # Run on CPU
-            model = model.to("cpu").eval()
-
-            translators_cache[lang] = {
-                "model": model,
-                "tokenizer": tokenizer
-            }
-            print(f"Successfully loaded translator for {lang}")
-
-        except Exception as e:
-            print(f"Error loading translator for {lang}: {e}")
-            return None
-
-    return translators_cache[lang]
-
-def translate_text(text, translator_info):
-    """Perform the translation."""
-    if translator_info is None:
-        return text
-
-    try:
-        tokenizer = translator_info["tokenizer"]
-        model = translator_info["model"]
-
-        # Preprocess the input text
-        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
-
-        # Run the translation
-        with torch.no_grad():
-            outputs = model.generate(**inputs)
-
-        # Decode the translation output
-        translated = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        print(f"Original text: {text}")
-        print(f"Translated text: {translated}")
-
-        return translated
-
-    except Exception as e:
-        print(f"Translation error: {e}")
-        return text
+# model = Flux().to(dtype=torch.bfloat16, device="cuda")
+# result = model.load_state_dict(load_file("/storage/dev/nyanko/flux-dev/flux1-dev.sft"))
 
 
 @spaces.GPU
 @torch.no_grad()
 def generate_image(
-    prompt,
+    prompt, width, height, guidance, inference_steps, seed,
     do_img2img, init_image, image2image_strength, resize_img,
     progress=gr.Progress(track_tqdm=True),
 ):
-
-
-
-
-
-            translated_prompt = translate_text(prompt, translator_info)
-            print(f"Using translated prompt: {translated_prompt}")
-        else:
-            print(f"No translator available for {source_lang}, using original prompt")
-            translated_prompt = prompt
-    else:
-        translated_prompt = prompt
-    except Exception as e:
-        print(f"Translation failed: {e}")
-        translated_prompt = prompt
-
-
-
+    translated_prompt = prompt
+
+    # Detect Korean or Japanese characters
+    def contains_korean(text):
+        return any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in text)
 
+    def contains_japanese(text):
+        return any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' or '\u4E00' <= c <= '\u9FFF' for c in text)
+
+    # Translate if the prompt contains Korean or Japanese
+    if contains_korean(prompt):
+        translated_prompt = ko_translator(prompt, max_length=512)[0]['translation_text']
+        print(f"Translated Korean prompt: {translated_prompt}")
+        prompt = translated_prompt
+    elif contains_japanese(prompt):
+        translated_prompt = ja_translator(prompt, max_length=512)[0]['translation_text']
+        print(f"Translated Japanese prompt: {translated_prompt}")
+        prompt = translated_prompt
+
     if seed == 0:
         seed = int(random.random() * 1000000)
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
     torch_device = torch.device(device)
+
+
 
     global model, model_zero_init
     if not model_zero_init:
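The new prompt-routing logic above is self-contained enough to test without the diffusion model. A small sketch that mirrors it, using a stub in place of the Opus-MT pipelines so nothing is downloaded:

def contains_korean(text):
    # Hangul Jamo and Hangul Syllables ranges
    return any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in text)

def contains_japanese(text):
    # Hiragana, Katakana, and CJK Unified Ideographs ranges
    return any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' or '\u4E00' <= c <= '\u9FFF' for c in text)

def route_prompt(prompt, ko_translator, ja_translator):
    """Same branching as generate_image: translate only Korean or Japanese prompts."""
    if contains_korean(prompt):
        return ko_translator(prompt, max_length=512)[0]["translation_text"]
    if contains_japanese(prompt):
        return ja_translator(prompt, max_length=512)[0]["translation_text"]
    return prompt

stub = lambda text, max_length=512: [{"translation_text": f"<EN of: {text}>"}]
print(contains_korean("벚꽃이 흩날리는 풍경"))       # True
print(contains_japanese("富士山と桜"))               # True
print(route_prompt("a cat in a hat", stub, stub))    # returned unchanged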
@@ -901,11 +802,10 @@ def generate_image(
         height = init_image.shape[-2]
         width = init_image.shape[-1]
         init_image = ae.encode(init_image.to(torch_device).to(torch.bfloat16)).latent_dist.sample()
-        init_image =
+        init_image = (init_image - ae.config.shift_factor) * ae.config.scaling_factor
 
     generator = torch.Generator(device=device).manual_seed(seed)
-    x = torch.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16),
-                    device=device, dtype=torch.bfloat16, generator=generator)
+    x = torch.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), device=device, dtype=torch.bfloat16, generator=generator)
 
     num_steps = inference_steps
     timesteps = get_schedule(num_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True)
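Two details in this hunk are easy to check in isolation: the img2img branch now normalizes latents as (z - shift_factor) * scaling_factor, which the decode path later inverts, and the noise tensor's spatial size 2 * ceil(H / 16) lines up with the VAE's 8x downsampling. A self-contained sketch with illustrative factor values (the app reads the real ones from ae.config):

import math
import torch

scaling_factor, shift_factor = 0.3611, 0.1159         # illustrative values only

z = torch.randn(1, 16, 96, 96)                         # stand-in for a VAE latent sample
z_norm = (z - shift_factor) * scaling_factor           # what the img2img branch stores
z_back = (z_norm / scaling_factor) + shift_factor      # what the decode path applies
print(torch.allclose(z, z_back, atol=1e-5))            # True: the transforms are inverses

height = width = 768
print(2 * math.ceil(height / 16), height // 8)         # 96 96 -> same latent resolution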
@@ -916,18 +816,22 @@
         timesteps = timesteps[t_idx:]
         x = t * x + (1.0 - t) * init_image.to(x.dtype)
 
-    inp = prepare(t5=t5, clip=clip, img=x, prompt=
+    inp = prepare(t5=t5, clip=clip, img=x, prompt=prompt)
     x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
 
+    # with profile(activities=[ProfilerActivity.CPU],record_shapes=True,profile_memory=True) as prof:
+    #     print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
+
     x = unpack(x.float(), height, width)
     with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
-        x = (x / ae.config.scaling_factor) + ae.config.shift_factor
+        x = x = (x / ae.config.scaling_factor) + ae.config.shift_factor
         x = ae.decode(x).sample
 
     x = x.clamp(-1, 1)
     x = rearrange(x[0], "c h w -> h w c")
     img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
 
+
     return img, seed, translated_prompt
 
 css = """
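The unchanged line x = t * x + (1.0 - t) * init_image shown in context is the img2img starting point: a convex blend of fresh noise and the encoded input, with t taken from the truncated schedule (a higher noising strength keeps more noise). A tiny sketch of just that mixing step, with made-up tensors:

import torch

t = 0.8                                     # roughly what a 0.8 noising strength selects
noise = torch.randn(1, 16, 96, 96)
init_latents = torch.randn(1, 16, 96, 96)   # stands in for the encoded, normalized init image
x = t * noise + (1.0 - t) * init_latents    # same mixing rule as in the hunk above
print(x.shape)                              # torch.Size([1, 16, 96, 96])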
@@ -936,21 +840,14 @@ footer {
 }
 """
 
+
 def create_demo():
     with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
+
         with gr.Row():
             with gr.Column():
-
-                    choices=["English"] + sorted(list(TRANSLATORS.keys())),
-                    value="English",
-                    label="Source Language"
-                )
+                prompt = gr.Textbox(label="Prompt(한글 가능)", value="A cute and fluffy golden retriever puppy sitting upright, holding a neatly designed white sign with bold, colorful lettering that reads 'Have a Happy Day!' in cheerful fonts. The puppy has expressive, sparkling eyes, a happy smile, and fluffy ears slightly flopped. The background is a vibrant and sunny meadow with soft-focus flowers, glowing sunlight filtering through the trees, and a warm golden glow that enhances the joyful atmosphere. The sign is framed with small decorative flowers, adding a charming and wholesome touch. Ensure the text on the sign is clear and legible.")
 
-                prompt = gr.Textbox(
-                    label="Prompt",
-                    value="A beautiful landscape"
-                )
-
                 width = gr.Slider(minimum=128, maximum=2048, step=64, label="Width", value=768)
                 height = gr.Slider(minimum=128, maximum=2048, step=64, label="Height", value=768)
                 guidance = gr.Slider(minimum=1.0, maximum=5.0, step=0.1, label="Guidance", value=3.5)
@@ -964,44 +861,13 @@ def create_demo():
                 seed = gr.Number(label="Seed", precision=-1)
                 do_img2img = gr.Checkbox(label="Image to Image", value=False)
                 init_image = gr.Image(label="Input Image", visible=False)
-                image2image_strength = gr.Slider(
-                    minimum=0.0, maximum=1.0, step=0.01,
-                    label="Noising strength", value=0.8, visible=False
-                )
+                image2image_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Noising strength", value=0.8, visible=False)
                 resize_img = gr.Checkbox(label="Resize image", value=True, visible=False)
                 generate_button = gr.Button("Generate")
 
             with gr.Column():
                 output_image = gr.Image(label="Generated Image")
                 output_seed = gr.Text(label="Used Seed")
-                translated_prompt = gr.Text(label="Translated Prompt")
-
-        # Multilingual examples
-        examples = [
-            # English
-            ["A beautiful sunset over mountains", "English", 768, 768, 3.5, 30, 0, False, None, 0.8, True],
-            # Korean
-            ["벚꽃이 흩날리는 서울의 봄 풍경", "Korean", 768, 768, 3.5, 30, 0, False, None, 0.8, True],
-            # Japanese
-            ["富士山と桜の美しい風景", "Japanese", 768, 768, 3.5, 30, 0, False, None, 0.8, True],
-            # Chinese
-            ["长城日落的壮丽景色", "Chinese", 768, 768, 3.5, 30, 0, False, None, 0.8, True],
-            # Spanish
-            ["Un hermoso atardecer en la playa", "Spanish", 768, 768, 3.5, 30, 0, False, None, 0.8, True]
-        ]
-
-        gr.Examples(
-            examples=examples,
-            inputs=[
-                prompt, source_lang, width, height, guidance, inference_steps,
-                seed, do_img2img, init_image, image2image_strength, resize_img
-            ],
-            outputs=[output_image, output_seed, translated_prompt],
-            fn=generate_image,
-            cache_examples=True
-        )
-
-
 
         do_img2img.change(
             fn=lambda x: [gr.update(visible=x), gr.update(visible=x), gr.update(visible=x)],
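The do_img2img.change handler above shows the standard Gradio pattern for toggling widget visibility: the callback returns one gr.update(visible=...) per output component. A minimal standalone sketch (a simplified two-widget layout, not the Space's full UI):

import gradio as gr

def toggle(show):
    # One gr.update per entry in `outputs`, in the same order.
    return [gr.update(visible=show), gr.update(visible=show)]

with gr.Blocks() as demo:
    use_img2img = gr.Checkbox(label="Image to Image", value=False)
    init_image = gr.Image(label="Input Image", visible=False)
    strength = gr.Slider(0.0, 1.0, value=0.8, label="Noising strength", visible=False)
    use_img2img.change(fn=toggle, inputs=[use_img2img], outputs=[init_image, strength])

# demo.launch()  # uncomment to run locally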
@@ -1011,16 +877,18 @@ def create_demo():
 
         generate_button.click(
             fn=generate_image,
-            inputs=[
-
-                seed, do_img2img, init_image, image2image_strength, resize_img
-            ],
-            outputs=[output_image, output_seed, translated_prompt]
+            inputs=[prompt, width, height, guidance, inference_steps, seed, do_img2img, init_image, image2image_strength, resize_img],
+            outputs=[output_image, output_seed]
         )
+
+        examples = [
+            "a tiny astronaut hatching from an egg on the moon",
+            "a cat holding a sign that says hello world",
+            "an anime illustration of a wiener schnitzel",
+        ]
 
     return demo
 
 if __name__ == "__main__":
-    print("Starting demo...")
     demo = create_demo()
-    demo.launch(
+    demo.launch()