Spaces:
Running
on
Zero
Running
on
Zero
from typing import Callable | |
from PIL import Image | |
import gradio as gr | |
from v2 import V2UI | |
from diffusion import ImageGenerator, image_generation_config_ui | |
from output import UpsamplingOutput | |
from utils import ( | |
PEOPLE_TAGS, | |
gradio_copy_text, | |
COPY_ACTION_JS, | |
) | |
NORMALIZE_RATING_TAG = { | |
"sfw": "", | |
"general": "", | |
"sensitive": "sensitive", | |
"nsfw": "nsfw", | |
"questionable": "nsfw", | |
"explicit": "nsfw, explicit", | |
} | |
def example( | |
copyright: str, | |
character: str, | |
general: str, | |
rating: str, | |
aspect_ratio: str, | |
length: str, | |
identity: str, | |
image_size: str, | |
): | |
return [ | |
copyright, | |
character, | |
general, | |
rating, | |
aspect_ratio, | |
length, | |
identity, | |
image_size, | |
] | |
GRADIO_EXAMPLES = [ | |
example( | |
copyright="original", | |
character="", | |
general="1girl, solo, upper body, :d", | |
rating="general", | |
aspect_ratio="tall", | |
length="long", | |
identity="none", | |
image_size="768x1344", | |
), | |
example( | |
copyright="original", | |
character="", | |
general="1girl, solo, blue theme, limited palette", | |
rating="sfw", | |
aspect_ratio="ultra_wide", | |
length="long", | |
identity="lax", | |
image_size="1536x640", | |
), | |
example( | |
copyright="", | |
character="", | |
general="4girls", | |
rating="sfw", | |
aspect_ratio="tall", | |
length="very_long", | |
identity="lax", | |
image_size="768x1344", | |
), | |
example( | |
copyright="original", | |
character="", | |
general="1girl, solo, upper body, looking at viewer, profile picture", | |
rating="sfw", | |
aspect_ratio="square", | |
length="medium", | |
identity="none", | |
image_size="1024x1024", | |
), | |
example( | |
copyright="original", | |
character="", | |
general="1girl, solo, chibi, upper body, looking at viewer, simple background, limited palette, square", | |
rating="sfw", | |
aspect_ratio="square", | |
length="medium", | |
identity="none", | |
image_size="1024x1024", | |
), | |
example( | |
copyright="original", | |
character="", | |
general="1girl, full body, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, solo, yellow shirt, simple background, green background", | |
rating="sfw", | |
aspect_ratio="tall", | |
length="very_long", | |
identity="strict", | |
image_size="768x1344", | |
), | |
example( | |
copyright="", | |
character="", | |
general="no humans, scenery, spring (season)", | |
rating="general", | |
aspect_ratio="ultra_wide", | |
length="medium", | |
identity="lax", | |
image_size="1536x640", | |
), | |
example( | |
copyright="", | |
character="", | |
general="no humans, cyberpunk, city, cityscape, building, neon lights, pixel art", | |
rating="general", | |
aspect_ratio="ultra_wide", | |
length="medium", | |
identity="lax", | |
image_size="1536x640", | |
), | |
example( | |
copyright="sousou no frieren", | |
character="frieren", | |
general="1girl, solo", | |
rating="general", | |
aspect_ratio="tall", | |
length="long", | |
identity="lax", | |
image_size="768x1344", | |
), | |
example( | |
copyright="honkai: star rail", | |
character="firefly (honkai: star rail)", | |
general="1girl, solo", | |
rating="sfw", | |
aspect_ratio="tall", | |
length="medium", | |
identity="lax", | |
image_size="768x1344", | |
), | |
example( | |
copyright="honkai: star rail", | |
character="silver wolf (honkai: star rail)", | |
general="1girl, solo, annoyed", | |
rating="sfw", | |
aspect_ratio="tall", | |
length="long", | |
identity="lax", | |
image_size="768x1344", | |
), | |
example( | |
copyright="chuunibyou demo koi ga shitai!", | |
character="takanashi rikka", | |
general="1girl, solo", | |
rating="sfw", | |
aspect_ratio="ultra_tall", | |
length="medium", | |
identity="lax", | |
image_size="640x1536", | |
), | |
] | |
def animagine_xl_v3_1(output: UpsamplingOutput): | |
# separate people tags (e.g. 1girl) | |
people_tags = [] | |
other_general_tags = [] | |
for tag in output.general_tags.split(","): | |
tag = tag.strip() | |
if tag in PEOPLE_TAGS: | |
people_tags.append(tag) | |
else: | |
other_general_tags.append(tag) | |
return ", ".join( | |
[ | |
part.strip() | |
for part in [ | |
*people_tags, | |
output.character_tags, | |
output.copyright_tags, | |
*other_general_tags, | |
output.upsampled_tags, | |
NORMALIZE_RATING_TAG[output.rating_tag], | |
] | |
if part.strip() != "" | |
] | |
) | |
def elapsed_time_format(elapsed_time: float) -> str: | |
return f"Elapsed: {elapsed_time:.2f} seconds" | |
def parse_upsampling_output( | |
upsampler: Callable[..., UpsamplingOutput], | |
): | |
def _parse_upsampling_output(*args) -> tuple[str, str, dict, dict]: | |
output = upsampler(*args) | |
print(output) | |
return ( | |
animagine_xl_v3_1(output), | |
elapsed_time_format(output.elapsed_time), | |
gr.update( | |
interactive=True, | |
), | |
gr.update( | |
interactive=True, | |
), | |
) | |
return _parse_upsampling_output | |
def description_ui(): | |
gr.Markdown( | |
""" | |
# Danbooru Tags Transformer V2 Demo | |
Models: | |
- [dart-v2-moe-sft](https://huggingface.co/p1atdev/dart-v2-moe-sft) (Mixtral architecture) | |
- [dart-v2-sft](https://huggingface.co/p1atdev/dart-v2-sft) (Mistral architecture) | |
- [Animagine XL v3.1](https://huggingface.co/cagliostrolab/animagine-xl-3.1) (Image generation model) | |
""" | |
) | |
def main(): | |
v2 = V2UI() | |
print("Loading diffusion model...") | |
image_generator = ImageGenerator() | |
print("Loaded.") | |
with gr.Blocks() as ui: | |
description_ui() | |
with gr.Row(): | |
with gr.Column(): | |
v2.ui() | |
with gr.Column(): | |
generate_btn = gr.Button(value="Generate tags", variant="primary") | |
with gr.Group(): | |
output_text = gr.TextArea(label="Output tags", interactive=False) | |
copy_btn = gr.Button( | |
value="Copy to clipboard", | |
interactive=False, | |
) | |
elapsed_time_md = gr.Markdown(label="Elapsed time", value="") | |
generate_image_btn = gr.Button( | |
value="Generate image with this prompt!", | |
interactive=False, | |
) | |
accordion, image_generation_config_components = ( | |
image_generation_config_ui() | |
) | |
output_image = gr.Gallery( | |
label="Generated image", | |
show_label=True, | |
columns=1, | |
preview=True, | |
visible=True, | |
) | |
gr.Examples( | |
examples=GRADIO_EXAMPLES, | |
inputs=[ | |
*v2.get_inputs()[1:8], | |
image_generation_config_components[0], # image_size | |
], | |
) | |
generate_btn.click( | |
parse_upsampling_output(v2.on_generate), | |
inputs=[ | |
*v2.get_inputs(), | |
], | |
outputs=[output_text, elapsed_time_md, copy_btn, generate_image_btn], | |
) | |
copy_btn.click(gradio_copy_text, inputs=[output_text], js=COPY_ACTION_JS) | |
generate_image_btn.click( | |
image_generator.generate, | |
inputs=[output_text, *image_generation_config_components], | |
outputs=[output_image], | |
) | |
ui.launch() | |
if __name__ == "__main__": | |
main() | |