Spaces:
Runtime error
Runtime error
Ron Au
commited on
Commit
•
81ea62c
1
Parent(s):
b1e2dc7
Initial D
Browse files- .gitignore +4 -0
- Pipfile +17 -0
- Pipfile.lock +0 -0
- README.md +2 -4
- app.py +48 -0
- modules/sprites.py +266 -0
- requirements.txt +80 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
!**/.gitkeep
|
2 |
+
|
3 |
+
cache/*
|
4 |
+
output/*
|
Pipfile
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[[source]]
|
2 |
+
url = "https://pypi.org/simple"
|
3 |
+
verify_ssl = true
|
4 |
+
name = "pypi"
|
5 |
+
|
6 |
+
[packages]
|
7 |
+
diffusers = { version = "==0.7.*", extras = ["torch"] }
|
8 |
+
gradio = "==3.9.*"
|
9 |
+
scipy = "==1.9.*"
|
10 |
+
torch = "==1.13.*"
|
11 |
+
torchvision = "==0.14.*"
|
12 |
+
transformers = "==4.24.*"
|
13 |
+
|
14 |
+
[dev-packages]
|
15 |
+
|
16 |
+
[requires]
|
17 |
+
python_version = "3.10"
|
Pipfile.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: purple
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
@@ -8,5 +8,3 @@ sdk_version: 3.9
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: Stable Diffusion Sprite Sheets
|
3 |
+
emoji: 🚶♀️
|
4 |
colorFrom: purple
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
|
|
|
app.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
import time
|
4 |
+
import gradio as gr
|
5 |
+
from modules.sprites import generate_sides, build_gifs, build_spritesheet
|
6 |
+
|
7 |
+
|
8 |
+
def generate(prompt, thresh):
|
9 |
+
timestamp = int(time.time())
|
10 |
+
|
11 |
+
sides = generate_sides(prompt, 3)[0]
|
12 |
+
spritesheet = build_spritesheet(sides, prompt, timestamp=timestamp, thresh=thresh)[0]
|
13 |
+
|
14 |
+
filepaths = build_gifs(sides, prompt, save=True, timestamp=timestamp, thresh=thresh)[1]
|
15 |
+
|
16 |
+
return spritesheet, filepaths[0], filepaths[1], filepaths[2], filepaths[3]
|
17 |
+
|
18 |
+
|
19 |
+
demo = gr.Blocks()
|
20 |
+
|
21 |
+
with demo:
|
22 |
+
gr.Markdown("""
|
23 |
+
# Stable Diffusion Sprite Sheets
|
24 |
+
|
25 |
+
Generate a sprite sheet of pixel art character sides and their corresponding walk animations! Checkpoint by [Onodofthenorth](https://huggingface.co/Onodofthenorth/SD_PixelArt_SpriteSheet_Generator). Sprites are 32x32 pixels scaled up to 96x96. NSFW content replaced with blank sprites.
|
26 |
+
""")
|
27 |
+
|
28 |
+
with gr.Row():
|
29 |
+
with gr.Column():
|
30 |
+
prompt = gr.Textbox(label="Prompt", placeholder="Enter text prompt")
|
31 |
+
threshold = gr.Slider(label="Background removal threshold", placeholder="Tweak how strong the background removal is", minimum=0, maximum=255, value=128)
|
32 |
+
|
33 |
+
button = gr.Button("Generate")
|
34 |
+
|
35 |
+
with gr.Box():
|
36 |
+
with gr.Row():
|
37 |
+
spritesheet = gr.Image(label="Sprite Sheet")
|
38 |
+
|
39 |
+
with gr.Row():
|
40 |
+
front = gr.Image(label="Front")
|
41 |
+
back = gr.Image(label="Back")
|
42 |
+
left = gr.Image(label="Left")
|
43 |
+
right = gr.Image(label="Right")
|
44 |
+
|
45 |
+
button.click(fn=generate, inputs=[prompt, threshold], outputs=[spritesheet, front, back, left, right])
|
46 |
+
|
47 |
+
demo.queue()
|
48 |
+
demo.launch(show_api=False)
|
modules/sprites.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import io
|
3 |
+
import re
|
4 |
+
import time
|
5 |
+
import random
|
6 |
+
import torch
|
7 |
+
from typing import Final, List, Optional, Tuple, cast
|
8 |
+
|
9 |
+
from PIL import Image, ImageDraw, ImageEnhance
|
10 |
+
from PIL.Image import Image as PILImage
|
11 |
+
from diffusers import StableDiffusionPipeline
|
12 |
+
|
13 |
+
model_id: Final = "Onodofthenorth/SD_PixelArt_SpriteSheet_Generator"
|
14 |
+
pipe = StableDiffusionPipeline.from_pretrained(
|
15 |
+
model_id, torch_dtype=torch.float16, cache_dir="cache"
|
16 |
+
)
|
17 |
+
pipe = pipe.to("cuda")
|
18 |
+
|
19 |
+
sprite_sides: Final = {
|
20 |
+
"front": "PixelArtFSS",
|
21 |
+
"right": "PixelArtRSS",
|
22 |
+
"back": "PixelArtBSS",
|
23 |
+
"left": "PixelArtLSS",
|
24 |
+
}
|
25 |
+
|
26 |
+
|
27 |
+
def torchGenerator(seed: Optional[int], max: int = 1024) -> Tuple[torch.Generator, int]:
|
28 |
+
seed = seed or random.randrange(0, max)
|
29 |
+
|
30 |
+
return torch.Generator("cuda").manual_seed(seed), seed
|
31 |
+
|
32 |
+
|
33 |
+
def generate(
|
34 |
+
prompt: str,
|
35 |
+
sfw_retries: int = 1,
|
36 |
+
seed: Optional[int] = None,
|
37 |
+
) -> PILImage:
|
38 |
+
"""
|
39 |
+
Generate a sprite image from a text description.
|
40 |
+
|
41 |
+
Return a blank image if the model fails to generate a safe image.
|
42 |
+
"""
|
43 |
+
|
44 |
+
generator = torchGenerator(seed)[0]
|
45 |
+
image: PILImage | None = None
|
46 |
+
|
47 |
+
for _ in range(sfw_retries):
|
48 |
+
pipe_output = pipe(prompt, generator=generator, width=512, height=512)
|
49 |
+
image = pipe_output.images[0]
|
50 |
+
|
51 |
+
if not pipe_output.nsfw_content_detected[0]:
|
52 |
+
break
|
53 |
+
|
54 |
+
rand_seed = seed
|
55 |
+
|
56 |
+
while rand_seed == seed:
|
57 |
+
print(f"Regenerating `{prompt}` with different seed.")
|
58 |
+
|
59 |
+
rand_seed = random.randrange(0, 1024)
|
60 |
+
generator = torchGenerator(rand_seed)[0]
|
61 |
+
|
62 |
+
return cast(PILImage, image)
|
63 |
+
|
64 |
+
|
65 |
+
def generate_sides(
|
66 |
+
prompt: str, sfw_retries: int = 1, sides: dict[str, str] = sprite_sides
|
67 |
+
) -> Tuple[dict[str, PILImage], str]:
|
68 |
+
"""
|
69 |
+
Generate sprite images from a text description of different sides.
|
70 |
+
|
71 |
+
If both left and right side specified, duplicate and flip left side as the right side
|
72 |
+
"""
|
73 |
+
|
74 |
+
print(f"Generating sprites for `{prompt}`")
|
75 |
+
|
76 |
+
seed = random.randrange(0, 1024)
|
77 |
+
sprites = {}
|
78 |
+
|
79 |
+
# If both left and right side specified, duplicate and flip left side as the right side
|
80 |
+
for side, label in sides.items():
|
81 |
+
if side == "right" and "left" in sides and "right" in sides:
|
82 |
+
continue
|
83 |
+
|
84 |
+
sprites[side] = generate(f"({prompt}) [nsfw] [photograph] {label}", sfw_retries, seed)
|
85 |
+
|
86 |
+
if "left" in sides and "right" in sides:
|
87 |
+
sprites["right"] = sprites["left"].transpose(Image.Transpose.FLIP_LEFT_RIGHT)
|
88 |
+
|
89 |
+
return sprites, prompt
|
90 |
+
|
91 |
+
|
92 |
+
def clean_sprite(
|
93 |
+
image: PILImage,
|
94 |
+
size: Tuple[int, int] = (192, 192),
|
95 |
+
sharpness: float = 1.5,
|
96 |
+
thresh: int = 128,
|
97 |
+
rescaling: Optional[int] = None,
|
98 |
+
) -> PILImage:
|
99 |
+
"""
|
100 |
+
Process image to be more sprite-like.
|
101 |
+
|
102 |
+
`rescale` will first scale down by value, then up to specified size.
|
103 |
+
"""
|
104 |
+
|
105 |
+
width, height = image.size
|
106 |
+
sharpener = ImageEnhance.Sharpness(image)
|
107 |
+
|
108 |
+
image = sharpener.enhance(sharpness)
|
109 |
+
image = image.convert("RGBA")
|
110 |
+
ImageDraw.floodfill(image, (0, 0), (255, 255, 255, 0), thresh=thresh)
|
111 |
+
|
112 |
+
if type(rescaling) is int:
|
113 |
+
image = image.resize(
|
114 |
+
(int(width / rescaling), int(height / rescaling)),
|
115 |
+
resample=Image.Resampling.NEAREST,
|
116 |
+
)
|
117 |
+
|
118 |
+
image = image.resize(size, resample=Image.Resampling.NEAREST)
|
119 |
+
|
120 |
+
return image
|
121 |
+
|
122 |
+
|
123 |
+
def split_sprites(image: PILImage, size: Tuple[int, int] = (96, 96)) -> List[PILImage]:
|
124 |
+
"""Split sprite image into individual sides."""
|
125 |
+
|
126 |
+
width, height = image.size
|
127 |
+
w, h = size
|
128 |
+
|
129 |
+
# fmt: off
|
130 |
+
frames = [
|
131 |
+
image.crop((
|
132 |
+
0,
|
133 |
+
int(h / 2),
|
134 |
+
int(width / 4),
|
135 |
+
int(height * 0.75),
|
136 |
+
)),
|
137 |
+
image.crop((
|
138 |
+
int(width / 4),
|
139 |
+
int(h / 2),
|
140 |
+
int(width / 4) * 2,
|
141 |
+
int(height * 0.75),
|
142 |
+
)),
|
143 |
+
image.crop((
|
144 |
+
int(width / 4) * 2,
|
145 |
+
int(h / 2),
|
146 |
+
int(width / 4) * 3,
|
147 |
+
int(height * 0.75),
|
148 |
+
)),
|
149 |
+
image.crop((
|
150 |
+
int(width / 4) * 3,
|
151 |
+
int(h / 2),
|
152 |
+
width,
|
153 |
+
int(height * 0.75),
|
154 |
+
)),
|
155 |
+
]
|
156 |
+
# fmt: on
|
157 |
+
|
158 |
+
new_canvas = Image.new("RGBA", size, (255, 255, 255, 0))
|
159 |
+
|
160 |
+
for i in range(len(frames)):
|
161 |
+
canvas = new_canvas.copy()
|
162 |
+
canvas.paste(frames[i], (int(w / 4), 0, int(w * 0.75), h))
|
163 |
+
frames[i] = canvas
|
164 |
+
|
165 |
+
return frames
|
166 |
+
|
167 |
+
|
168 |
+
def build_spritesheet(
|
169 |
+
images: dict[str, PILImage],
|
170 |
+
text: str = "sd_pixelart",
|
171 |
+
sprite_size: Tuple[int, int] = (96, 96),
|
172 |
+
dir: str = "output",
|
173 |
+
save: bool = False,
|
174 |
+
timestamp: Optional[int] = None,
|
175 |
+
thresh: int = 128,
|
176 |
+
) -> Tuple[PILImage, str | None]:
|
177 |
+
"""
|
178 |
+
Build sprite sheet from sides.
|
179 |
+
|
180 |
+
1. Clean and scale each image
|
181 |
+
2. Split each image into individual frames
|
182 |
+
3. Create a new spritesheet canvas for all sides[frames]
|
183 |
+
4. Paste each individial frame onto canvas
|
184 |
+
"""
|
185 |
+
|
186 |
+
frames = {}
|
187 |
+
width, height = sprite_size
|
188 |
+
text = re.sub(r"[^\w()[\]_-]", "", text)
|
189 |
+
filepath = None
|
190 |
+
|
191 |
+
for side, image in images.items():
|
192 |
+
image = clean_sprite(image, (width * 2, height * 2), thresh=thresh)
|
193 |
+
frames[side] = split_sprites(image, sprite_size)
|
194 |
+
|
195 |
+
canvas = Image.new(
|
196 |
+
"RGBA",
|
197 |
+
(width * len(frames["front"]), height * len(frames)),
|
198 |
+
(255, 255, 255, 0),
|
199 |
+
)
|
200 |
+
|
201 |
+
for j in range(len(frames["front"])):
|
202 |
+
for k, side in enumerate(frames):
|
203 |
+
canvas.paste(
|
204 |
+
frames[side][j],
|
205 |
+
(
|
206 |
+
j * width,
|
207 |
+
k * height,
|
208 |
+
j * width + width,
|
209 |
+
k * height + height,
|
210 |
+
),
|
211 |
+
)
|
212 |
+
|
213 |
+
spritesheet = io.BytesIO()
|
214 |
+
canvas.save(spritesheet, "PNG")
|
215 |
+
|
216 |
+
if save:
|
217 |
+
timestamp = timestamp or int(time.time())
|
218 |
+
filepath = os.path.join(dir, f"{timestamp}_{text}.png")
|
219 |
+
canvas.save(filepath)
|
220 |
+
|
221 |
+
return Image.open(spritesheet), filepath
|
222 |
+
|
223 |
+
|
224 |
+
def build_gifs(
|
225 |
+
images: dict[str, PILImage],
|
226 |
+
text: str = "sd_spritesheet",
|
227 |
+
dir: str = "output",
|
228 |
+
duration: int | List[int] | Tuple[int, ...] = (300, 450, 300, 450),
|
229 |
+
save: bool = False,
|
230 |
+
timestamp: Optional[int] = None,
|
231 |
+
thresh: int = 128,
|
232 |
+
) -> Tuple[dict[str, List[PILImage]], List[str] | None]:
|
233 |
+
"""Build animated GIFs from side frames."""
|
234 |
+
|
235 |
+
gifs = {}
|
236 |
+
text = re.sub(r"[^\w()[\]_-]", "", text)
|
237 |
+
filepaths = [] if save else None
|
238 |
+
|
239 |
+
for side, image in images.items():
|
240 |
+
image = clean_sprite(image, thresh=thresh)
|
241 |
+
frames = split_sprites(image)
|
242 |
+
|
243 |
+
gif = io.BytesIO()
|
244 |
+
|
245 |
+
options = {
|
246 |
+
"fp": gif,
|
247 |
+
"format": "GIF",
|
248 |
+
"save_all": True,
|
249 |
+
"append_images": frames[1:],
|
250 |
+
"disposal": 3,
|
251 |
+
"duration": duration,
|
252 |
+
"loop": 0,
|
253 |
+
}
|
254 |
+
|
255 |
+
frames[0].save(**options)
|
256 |
+
gifs[side] = Image.open(gif)
|
257 |
+
|
258 |
+
if save:
|
259 |
+
timestamp = timestamp or int(time.time())
|
260 |
+
filepath = os.path.join(dir, f"{timestamp}_{text}_{side}.gif")
|
261 |
+
filepaths.append(filepath)
|
262 |
+
|
263 |
+
options.update({"fp": filepath})
|
264 |
+
frames[0].save(**options)
|
265 |
+
|
266 |
+
return gifs, filepaths
|
requirements.txt
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-i https://pypi.org/simple
|
2 |
+
accelerate==0.14.0
|
3 |
+
aiohttp==3.8.3
|
4 |
+
aiosignal==1.3.1
|
5 |
+
anyio==3.6.2
|
6 |
+
async-timeout==4.0.2
|
7 |
+
attrs==22.1.0
|
8 |
+
bcrypt==4.0.1
|
9 |
+
certifi==2022.9.24
|
10 |
+
cffi==1.15.1
|
11 |
+
charset-normalizer==2.1.1
|
12 |
+
click==8.1.3
|
13 |
+
contourpy==1.0.6
|
14 |
+
cryptography==38.0.3
|
15 |
+
cycler==0.11.0
|
16 |
+
diffusers==0.7.2
|
17 |
+
fastapi==0.86.0
|
18 |
+
ffmpy==0.3.0
|
19 |
+
filelock==3.8.0
|
20 |
+
fonttools==4.38.0
|
21 |
+
frozenlist==1.3.3
|
22 |
+
fsspec==2022.11.0
|
23 |
+
gradio==3.9.1
|
24 |
+
h11==0.12.0
|
25 |
+
httpcore==0.15.0
|
26 |
+
httpx==0.23.0
|
27 |
+
huggingface-hub==0.10.1
|
28 |
+
idna==3.4
|
29 |
+
importlib-metadata==5.0.0
|
30 |
+
jinja2==3.1.2
|
31 |
+
kiwisolver==1.4.4
|
32 |
+
linkify-it-py==1.0.3
|
33 |
+
markdown-it-py==2.1.0
|
34 |
+
markupsafe==2.1.1
|
35 |
+
matplotlib==3.6.2
|
36 |
+
mdit-py-plugins==0.3.1
|
37 |
+
mdurl==0.1.2
|
38 |
+
multidict==6.0.2
|
39 |
+
numpy==1.23.4
|
40 |
+
nvidia-cublas-cu11==11.10.3.66
|
41 |
+
nvidia-cuda-nvrtc-cu11==11.7.99
|
42 |
+
nvidia-cuda-runtime-cu11==11.7.99
|
43 |
+
nvidia-cudnn-cu11==8.5.0.96
|
44 |
+
orjson==3.8.1
|
45 |
+
packaging==21.3
|
46 |
+
pandas==1.5.1
|
47 |
+
paramiko==2.12.0
|
48 |
+
pillow==9.3.0
|
49 |
+
psutil==5.9.4
|
50 |
+
pycparser==2.21
|
51 |
+
pycryptodome==3.15.0
|
52 |
+
pydantic==1.10.2
|
53 |
+
pydub==0.25.1
|
54 |
+
pynacl==1.5.0
|
55 |
+
pyparsing==3.0.9
|
56 |
+
python-dateutil==2.8.2
|
57 |
+
python-multipart==0.0.5
|
58 |
+
pytz==2022.6
|
59 |
+
pyyaml==6.0
|
60 |
+
regex==2022.10.31
|
61 |
+
requests==2.28.1
|
62 |
+
rfc3986==1.5.0
|
63 |
+
scipy==1.9.3
|
64 |
+
setuptools==65.5.1
|
65 |
+
six==1.16.0
|
66 |
+
sniffio==1.3.0
|
67 |
+
starlette==0.20.4
|
68 |
+
tokenizers==0.13.2
|
69 |
+
torch==1.13.0
|
70 |
+
torchvision==0.14.0
|
71 |
+
tqdm==4.64.1
|
72 |
+
transformers==4.24.0
|
73 |
+
typing-extensions==4.4.0
|
74 |
+
uc-micro-py==1.0.1
|
75 |
+
urllib3==1.26.12
|
76 |
+
uvicorn==0.19.0
|
77 |
+
websockets==10.4
|
78 |
+
wheel==0.38.4
|
79 |
+
yarl==1.8.1
|
80 |
+
zipp==3.10.0
|