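"""Gradio demo for Danbooru Tags Transformer V2.

Upsamples Danbooru tags with the Dart v2 models, formats the result as an
Animagine XL v3.1 prompt, and can generate an image from it.
"""
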
from typing import Callable
from PIL import Image
import gradio as gr
from v2 import V2UI
from diffusion import ImageGenerator, image_generation_config_ui
from output import UpsamplingOutput
from utils import (
    PEOPLE_TAGS,
    gradio_copy_text,
    COPY_ACTION_JS,
)
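
# Map Dart v2 rating tags to the rating tags Animagine XL v3.1 expects.
# "sfw" and "general" map to "" so that no rating tag is appended at all.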
NORMALIZE_RATING_TAG = {
    "sfw": "",
    "general": "",
    "sensitive": "sensitive",
    "nsfw": "nsfw",
    "questionable": "nsfw",
    "explicit": "nsfw, explicit",
}

def example(
    copyright: str,
    character: str,
    general: str,
    rating: str,
    aspect_ratio: str,
    length: str,
    identity: str,
    image_size: str,
):
    """Pack one named example into the positional list gr.Examples expects."""
    return [
        copyright,
        character,
        general,
        rating,
        aspect_ratio,
        length,
        identity,
        image_size,
    ]

GRADIO_EXAMPLES = [
    example(
        copyright="original",
        character="",
        general="1girl, solo, upper body, :d",
        rating="general",
        aspect_ratio="tall",
        length="long",
        identity="none",
        image_size="768x1344",
    ),
    example(
        copyright="original",
        character="",
        general="1girl, solo, blue theme, limited palette",
        rating="sfw",
        aspect_ratio="ultra_wide",
        length="long",
        identity="lax",
        image_size="1536x640",
    ),
    example(
        copyright="",
        character="",
        general="4girls",
        rating="sfw",
        aspect_ratio="tall",
        length="very_long",
        identity="lax",
        image_size="768x1344",
    ),
    example(
        copyright="original",
        character="",
        general="1girl, solo, upper body, looking at viewer, profile picture",
        rating="sfw",
        aspect_ratio="square",
        length="medium",
        identity="none",
        image_size="1024x1024",
    ),
    example(
        copyright="original",
        character="",
        general="1girl, solo, chibi, upper body, looking at viewer, simple background, limited palette, square",
        rating="sfw",
        aspect_ratio="square",
        length="medium",
        identity="none",
        image_size="1024x1024",
    ),
    example(
        copyright="original",
        character="",
        general="1girl, full body, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, solo, yellow shirt, simple background, green background",
        rating="sfw",
        aspect_ratio="tall",
        length="very_long",
        identity="strict",
        image_size="768x1344",
    ),
    example(
        copyright="",
        character="",
        general="no humans, scenery, spring (season)",
        rating="general",
        aspect_ratio="ultra_wide",
        length="medium",
        identity="lax",
        image_size="1536x640",
    ),
    example(
        copyright="",
        character="",
        general="no humans, cyberpunk, city, cityscape, building, neon lights, pixel art",
        rating="general",
        aspect_ratio="ultra_wide",
        length="medium",
        identity="lax",
        image_size="1536x640",
    ),
    example(
        copyright="sousou no frieren",
        character="frieren",
        general="1girl, solo",
        rating="general",
        aspect_ratio="tall",
        length="long",
        identity="lax",
        image_size="768x1344",
    ),
    example(
        copyright="honkai: star rail",
        character="firefly (honkai: star rail)",
        general="1girl, solo",
        rating="sfw",
        aspect_ratio="tall",
        length="medium",
        identity="lax",
        image_size="768x1344",
    ),
    example(
        copyright="honkai: star rail",
        character="silver wolf (honkai: star rail)",
        general="1girl, solo, annoyed",
        rating="sfw",
        aspect_ratio="tall",
        length="long",
        identity="lax",
        image_size="768x1344",
    ),
    example(
        copyright="chuunibyou demo koi ga shitai!",
        character="takanashi rikka",
        general="1girl, solo",
        rating="sfw",
        aspect_ratio="ultra_tall",
        length="medium",
        identity="lax",
        image_size="640x1536",
    ),
]
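
# Each row above populates gr.Examples in main(); the column order must match
# the inputs wired up there (the v2 UI inputs followed by the image size).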
def animagine_xl_v3_1(output: UpsamplingOutput):
    """Format an UpsamplingOutput as an Animagine XL v3.1 prompt:
    people tags, character, copyright, remaining general tags,
    upsampled tags, then the normalized rating tag."""
    # separate people tags (e.g. 1girl) so they can lead the prompt
    people_tags = []
    other_general_tags = []
    for tag in output.general_tags.split(","):
        tag = tag.strip()
        if tag in PEOPLE_TAGS:
            people_tags.append(tag)
        else:
            other_general_tags.append(tag)

    return ", ".join(
        [
            part.strip()
            for part in [
                *people_tags,
                output.character_tags,
                output.copyright_tags,
                *other_general_tags,
                output.upsampled_tags,
                NORMALIZE_RATING_TAG[output.rating_tag],
            ]
            if part.strip() != ""
        ]
    )
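
# A rough sketch of the resulting ordering (hypothetical values; assumes
# PEOPLE_TAGS contains "1girl" but not "solo", and that UpsamplingOutput
# can be constructed with these keyword arguments):
#
#   output = UpsamplingOutput(
#       general_tags="1girl, solo",
#       character_tags="frieren",
#       copyright_tags="sousou no frieren",
#       upsampled_tags="long hair, pointy ears",
#       rating_tag="general",
#       elapsed_time=0.5,
#   )
#   animagine_xl_v3_1(output)
#   # -> "1girl, frieren, sousou no frieren, solo, long hair, pointy ears"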
def elapsed_time_format(elapsed_time: float) -> str:
    return f"Elapsed: {elapsed_time:.2f} seconds"
def parse_upsampling_output(
    upsampler: Callable[..., UpsamplingOutput],
):
    def _parse_upsampling_output(*args) -> tuple[str, str, dict, dict]:
        output = upsampler(*args)
        print(output)  # log the raw output for debugging

        return (
            animagine_xl_v3_1(output),
            elapsed_time_format(output.elapsed_time),
            gr.update(interactive=True),  # enable the copy button
            gr.update(interactive=True),  # enable the image-generation button
        )

    return _parse_upsampling_output

def description_ui():
    gr.Markdown(
        """
# Danbooru Tags Transformer V2 Demo

Models:
- [dart-v2-moe-sft](https://huggingface.co/p1atdev/dart-v2-moe-sft) (Mixtral architecture)
- [dart-v2-sft](https://huggingface.co/p1atdev/dart-v2-sft) (Mistral architecture)
- [Animagine XL v3.1](https://huggingface.co/cagliostrolab/animagine-xl-3.1) (image generation model)
"""
    )

def main():
    v2 = V2UI()

    print("Loading diffusion model...")
    image_generator = ImageGenerator()
    print("Loaded.")

    with gr.Blocks() as ui:
        description_ui()

        with gr.Row():
            with gr.Column():
                v2.ui()

            with gr.Column():
                generate_btn = gr.Button(value="Generate tags", variant="primary")

                with gr.Group():
                    output_text = gr.TextArea(label="Output tags", interactive=False)
                    copy_btn = gr.Button(
                        value="Copy to clipboard",
                        interactive=False,
                    )

                elapsed_time_md = gr.Markdown(label="Elapsed time", value="")

                generate_image_btn = gr.Button(
                    value="Generate image with this prompt!",
                    interactive=False,
                )

                # the accordion itself is unused below; only its components
                # are passed on to the image generator
                accordion, image_generation_config_components = (
                    image_generation_config_ui()
                )

                output_image = gr.Gallery(
                    label="Generated image",
                    show_label=True,
                    columns=1,
                    preview=True,
                    visible=True,
                )

        gr.Examples(
            examples=GRADIO_EXAMPLES,
            inputs=[
                *v2.get_inputs()[1:8],
                image_generation_config_components[0],  # image_size
            ],
        )

        generate_btn.click(
            parse_upsampling_output(v2.on_generate),
            inputs=[*v2.get_inputs()],
            outputs=[output_text, elapsed_time_md, copy_btn, generate_image_btn],
        )
        copy_btn.click(gradio_copy_text, inputs=[output_text], js=COPY_ACTION_JS)
        generate_image_btn.click(
            image_generator.generate,
            inputs=[output_text, *image_generation_config_components],
            outputs=[output_image],
        )

    ui.launch()

if __name__ == "__main__":
    main()