mkshing commited on
Commit
0ecbb8c
·
verified ·
1 Parent(s): ec826a6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ from PIL import Image
4
+ import gradio as gr
5
+ import numpy as np
6
+ import spaces
7
+ import torch
8
+
9
+ from evoukiyoe_v1 import load_evoukiyoe
10
+
11
+
12
+ DESCRIPTION = """# 🐟 EvoUkiyo-e
13
+ 🤗 [モデル一覧](https://huggingface.co/SakanaAI) | 📚 [技術レポート](https://arxiv.org/abs/2403.13187) | 📝 [ブログ](https://sakana.ai/evosdxl-jp/) | 🐦 [Twitter](https://twitter.com/SakanaAILabs)
14
+
15
+ [EvoUkiyo-e](https://huggingface.co/SakanaAI/EvoUkiyo-e-v1)は[Sakana AI](https://sakana.ai/)が教育目的で開発した日本特化の高速な画像生成モデルです。
16
+ 入力した日本語プロンプトに沿った画像を生成することができます。より詳しくは、上記のブログをご参照ください。
17
+ """
18
+ if not torch.cuda.is_available():
19
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"
20
+
21
+ MAX_SEED = np.iinfo(np.int32).max
22
+
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+
25
+ NUM_IMAGES_PER_PROMPT = 1
26
+ SAFETY_CHECKER = True
27
+ if SAFETY_CHECKER:
28
+ from safety_checker import StableDiffusionSafetyChecker
29
+ from transformers import CLIPFeatureExtractor
30
+
31
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(
32
+ "CompVis/stable-diffusion-safety-checker"
33
+ ).to(device)
34
+ feature_extractor = CLIPFeatureExtractor.from_pretrained(
35
+ "openai/clip-vit-base-patch32"
36
+ )
37
+
38
+ def check_nsfw_images(
39
+ images: list[Image.Image],
40
+ ) -> tuple[list[Image.Image], list[bool]]:
41
+ safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
42
+ has_nsfw_concepts = safety_checker(
43
+ images=[images], clip_input=safety_checker_input.pixel_values.to(device)
44
+ )
45
+
46
+ return images, has_nsfw_concepts
47
+
48
+
49
+ pipe = load_evoukiyoe("cpu").to(device)
50
+
51
+
52
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
53
+ if randomize_seed:
54
+ seed = random.randint(0, MAX_SEED)
55
+ return seed
56
+
57
+
58
+ @spaces.GPU
59
+ @torch.inference_mode()
60
+ def generate(
61
+ prompt: str,
62
+ seed: int = 0,
63
+ randomize_seed: bool = False,
64
+ progress=gr.Progress(track_tqdm=True),
65
+ ):
66
+ pipe.to(device)
67
+ seed = int(randomize_seed_fn(seed, randomize_seed))
68
+ generator = torch.Generator().manual_seed(seed)
69
+
70
+ images = pipe(
71
+ prompt=prompt + "輻の浮世絵。",
72
+ width=1024,
73
+ height=1024,
74
+ guidance_scale=8.0,
75
+ num_inference_steps=40,
76
+ generator=generator,
77
+ num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
78
+ output_type="pil",
79
+ ).images
80
+
81
+ if SAFETY_CHECKER:
82
+ images, has_nsfw_concepts = check_nsfw_images(images)
83
+ if any(has_nsfw_concepts):
84
+ gr.Warning("NSFW content detected.")
85
+ return Image.new("RGB", (512, 512), "WHITE"), seed
86
+ return images[0], seed
87
+
88
+
89
+ examples = [
90
+ ["魚が泳いでいる。"],
91
+ ["熊が本を読んでいる。"],
92
+ ["猫が畳の上で寝ている。"],
93
+ ["象が刀を持っている。"],
94
+ ["男性と女性が戦っている。"],
95
+ ["富士山、桜の木、川と人々の風景。"],
96
+ ]
97
+
98
+ css = """
99
+ .gradio-container{max-width: 690px !important}
100
+ h1{text-align:center}
101
+ """
102
+ with gr.Blocks(css=css) as demo:
103
+ gr.Markdown(DESCRIPTION)
104
+ with gr.Group():
105
+ with gr.Row():
106
+ prompt = gr.Textbox(placeholder="日本語でプロンプトを入力してください。", show_label=False, scale=8)
107
+ submit = gr.Button(scale=0)
108
+ result = gr.Image(label="EvoUkiyo-eからの生成結果", show_label=False)
109
+ with gr.Accordion("詳細設定", open=False):
110
+ seed = gr.Slider(label="シード値", minimum=0, maximum=MAX_SEED, step=1, value=0)
111
+ randomize_seed = gr.Checkbox(label="ランダムにシード値を決定", value=True)
112
+ gr.Examples(examples=examples, inputs=prompt, outputs=[result, seed], fn=generate)
113
+ gr.on(
114
+ triggers=[
115
+ prompt.submit,
116
+ submit.click,
117
+ ],
118
+ fn=generate,
119
+ inputs=[
120
+ prompt,
121
+ seed,
122
+ randomize_seed,
123
+ ],
124
+ outputs=[result, seed],
125
+ api_name="run",
126
+ )
127
+ gr.Markdown("""⚠️ 本モデルは実験段階のプロトタイプであり、教育および研究開発の目的でのみ提供されています。商用利用や、障害が重大な影響を及ぼす可能性のある環境(ミッションクリティカルな環境)での使用には適していません。
128
+ 本モデルの使用は、利用者の自己責任で行われ、その性能や結果については何ら保証されません。
129
+ Sakana AIは、本モデルの使用によって生じた直接的または間接的な損失に対して、結果に関わらず、一切の責任を負いません。
130
+ 利用者は、本モデルの使用に伴うリスクを十分に理解し、自身の判断で使用することが必要です。""")
131
+
132
+ demo.queue().launch()