import os
import io
import re
import time
import random
import torch
from typing import Final, List, Optional, Tuple, cast
from PIL import Image, ImageDraw, ImageEnhance
from PIL.Image import Image as PILImage
from diffusers import StableDiffusionPipeline

model_id: Final = "Onodofthenorth/SD_PixelArt_SpriteSheet_Generator"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16, cache_dir="cache"
)
pipe = pipe.to("cuda")

# Maps a side name to the prompt token appended for that side in `generate_sides`.
sprite_sides: Final = {
    "front": "PixelArtFSS",
    "right": "PixelArtRSS",
    "back": "PixelArtBSS",
    "left": "PixelArtLSS",
}

# Characters stripped from user text before it is used in an output filename.
_UNSAFE_TEXT_CHARS: Final = re.compile(r"[^\w()[\]_-]")


def _sanitize_text(text: str) -> str:
    """Remove every character not allowed in the generated filenames."""
    return _UNSAFE_TEXT_CHARS.sub("", text)


def torchGenerator(seed: Optional[int], max: int = 1024) -> Tuple[torch.Generator, int]:
    """
    Create a CUDA `torch.Generator` seeded with `seed`.

    When `seed` is None a random seed in `[0, max)` is drawn instead.
    Returns the generator together with the seed that was actually used.
    """
    # Compare against None explicitly: 0 is a valid seed and must not be
    # replaced by a random one.
    if seed is None:
        seed = random.randrange(0, max)

    return torch.Generator("cuda").manual_seed(seed), seed


def generate(
    prompt: str,
    sfw_retries: int = 1,
    seed: Optional[int] = None,
) -> PILImage:
    """
    Generate a sprite image from a text description.

    The pipeline is run up to `sfw_retries` times (at least once), each retry
    with a different seed, until the output is not flagged as NSFW.
    If every attempt is flagged, the image from the last attempt is returned
    as-is (presumably blanked by the pipeline's safety checker).
    """
    # Always run at least once so an image is returned even for sfw_retries <= 0.
    attempts = max(1, sfw_retries)
    generator, current_seed = torchGenerator(seed)
    image: PILImage | None = None

    for attempt in range(attempts):
        pipe_output = pipe(prompt, generator=generator, width=512, height=512)
        image = pipe_output.images[0]

        # `nsfw_content_detected` may be None when no safety check was run.
        nsfw_flags = pipe_output.nsfw_content_detected
        if not nsfw_flags or not nsfw_flags[0]:
            break

        if attempt == attempts - 1:
            # Out of attempts: nothing left to regenerate.
            break

        print(f"Regenerating `{prompt}` with different seed.")
        new_seed = current_seed
        while new_seed == current_seed:
            new_seed = random.randrange(0, 1024)
        generator, current_seed = torchGenerator(new_seed)

    return cast(PILImage, image)


def generate_sides(
    prompt: str, sfw_retries: int = 1, sides: dict[str, str] = sprite_sides
) -> Tuple[dict[str, PILImage], str]:
    """
    Generate sprite images from a text description of different sides.

    All sides are generated with the same randomly drawn seed.
    If both left and right side specified, duplicate and flip left side as
    the right side. In that case "right" is inserted last in the returned
    dict, regardless of its position in `sides`.

    Returns the mapping of side name to image, and the prompt.
    """
    print(f"Generating sprites for `{prompt}`")

    seed = random.randrange(0, 1024)
    mirror_right = "left" in sides and "right" in sides
    sprites: dict[str, PILImage] = {}

    for side, label in sides.items():
        if mirror_right and side == "right":
            # Derived from the left side below instead of being generated.
            continue

        sprites[side] = generate(
            f"({prompt}) [nsfw] [photograph] {label}", sfw_retries, seed
        )

    if mirror_right:
        sprites["right"] = sprites["left"].transpose(Image.Transpose.FLIP_LEFT_RIGHT)

    return sprites, prompt


def clean_sprite(
    image: PILImage,
    size: Tuple[int, int] = (192, 192),
    sharpness: float = 1.5,
    thresh: int = 128,
    rescaling: Optional[int] = None,
) -> PILImage:
    """
    Process image to be more sprite-like.

    The image is sharpened, converted to RGBA, and flood-filled from the
    top-left pixel with transparent white (tolerance `thresh`).
    `rescaling` will first scale down by value, then up to specified size;
    both resizes use nearest-neighbour resampling.

    Raises ValueError if `rescaling` is an int smaller than 1.
    """
    width, height = image.size

    image = ImageEnhance.Sharpness(image).enhance(sharpness)
    image = image.convert("RGBA")
    ImageDraw.floodfill(image, (0, 0), (255, 255, 255, 0), thresh=thresh)

    # Only true ints trigger the downscale step (bools and other types are
    # ignored, as before).
    if isinstance(rescaling, int) and not isinstance(rescaling, bool):
        if rescaling < 1:
            raise ValueError(f"rescaling must be >= 1, got {rescaling}")

        # Clamp to 1px so a large factor cannot produce an empty image.
        downscaled = (
            max(1, int(width / rescaling)),
            max(1, int(height / rescaling)),
        )
        image = image.resize(downscaled, resample=Image.Resampling.NEAREST)

    return image.resize(size, resample=Image.Resampling.NEAREST)


def split_sprites(image: PILImage, size: Tuple[int, int] = (96, 96)) -> List[PILImage]:
    """
    Split sprite image into individual sides.

    The image is cut into four equal-width columns; each column is cropped
    vertically from `size[1] / 2` to 75% of the image height and pasted,
    horizontally centred, onto a transparent canvas of `size`.

    NOTE(review): the paste box only matches the cropped frame when `image`
    is twice `size` in both dimensions, which is how the callers in this
    file invoke it.
    """
    width, height = image.size
    w, h = size

    frame_width = int(width / 4)
    top = int(h / 2)
    bottom = int(height * 0.75)
    paste_box = (int(w / 4), 0, int(w * 0.75), h)
    blank = Image.new("RGBA", size, (255, 255, 255, 0))

    frames: List[PILImage] = []
    for index in range(4):
        left = index * frame_width
        # The last column absorbs any remainder from the integer division.
        right = width if index == 3 else left + frame_width

        canvas = blank.copy()
        canvas.paste(image.crop((left, top, right, bottom)), paste_box)
        frames.append(canvas)

    return frames


def build_spritesheet(
    images: dict[str, PILImage],
    text: str = "sd_pixelart",
    sprite_size: Tuple[int, int] = (96, 96),
    dir: str = "output",
    save: bool = False,
    timestamp: Optional[int] = None,
    thresh: int = 128,
) -> Tuple[PILImage, str | None]:
    """
    Build sprite sheet from sides.

    1. Clean and scale each image
    2. Split each image into individual frames
    3. Create a new spritesheet canvas for all sides[frames]
    4. Paste each individual frame onto canvas

    Each side becomes one row (in `images` order) and each frame one column.
    When `save` is true the sheet is also written to
    `<dir>/<timestamp>_<text>.png`, creating `dir` if needed.

    Returns the sheet (re-opened from an in-memory PNG) and the saved file
    path, or None when nothing was saved.
    Raises ValueError if `images` is empty.
    """
    if not images:
        raise ValueError("images must contain at least one side")

    width, height = sprite_size
    text = _sanitize_text(text)
    filepath = None

    frames: dict[str, List[PILImage]] = {}
    for side, image in images.items():
        cleaned = clean_sprite(image, (width * 2, height * 2), thresh=thresh)
        frames[side] = split_sprites(cleaned, sprite_size)

    # Every side yields the same number of frames, so take the count from the
    # first side instead of requiring a "front" entry.
    columns = len(next(iter(frames.values())))

    canvas = Image.new(
        "RGBA",
        (width * columns, height * len(frames)),
        (255, 255, 255, 0),
    )

    for row, side_frames in enumerate(frames.values()):
        for column in range(columns):
            left, top = column * width, row * height
            canvas.paste(
                side_frames[column],
                (left, top, left + width, top + height),
            )

    spritesheet = io.BytesIO()
    canvas.save(spritesheet, "PNG")

    if save:
        if timestamp is None:
            timestamp = int(time.time())
        if dir:
            os.makedirs(dir, exist_ok=True)
        filepath = os.path.join(dir, f"{timestamp}_{text}.png")
        canvas.save(filepath)

    return Image.open(spritesheet), filepath


def build_gifs(
    images: dict[str, PILImage],
    text: str = "sd_spritesheet",
    dir: str = "output",
    duration: int | List[int] | Tuple[int, ...] = (300, 450, 300, 450),
    save: bool = False,
    timestamp: Optional[int] = None,
    thresh: int = 128,
) -> Tuple[dict[str, PILImage], List[str] | None]:
    """
    Build animated GIFs from side frames.

    Each side image is cleaned, split into frames and encoded as a looping
    GIF. When `save` is true every GIF is also written to
    `<dir>/<timestamp>_<text>_<side>.gif`, creating `dir` if needed.

    Returns a mapping of side name to the GIF (opened from memory), and the
    list of saved file paths, or None when `save` is false.
    """
    text = _sanitize_text(text)
    gifs: dict[str, PILImage] = {}
    filepaths: List[str] | None = None

    if save:
        filepaths = []
        # Resolve once so all sides share the same timestamp prefix.
        if timestamp is None:
            timestamp = int(time.time())
        if dir:
            os.makedirs(dir, exist_ok=True)

    for side, image in images.items():
        frames = split_sprites(clean_sprite(image, thresh=thresh))

        gif = io.BytesIO()
        options = {
            "fp": gif,
            "format": "GIF",
            "save_all": True,
            "append_images": frames[1:],
            "disposal": 3,
            "duration": duration,
            "loop": 0,
        }
        frames[0].save(**options)
        gifs[side] = Image.open(gif)

        if filepaths is not None:
            filepath = os.path.join(dir, f"{timestamp}_{text}_{side}.gif")
            filepaths.append(filepath)
            # Same encoding options, written to disk instead of memory.
            options["fp"] = filepath
            frames[0].save(**options)

    return gifs, filepaths