|
# Deployment switches. `prod` selects the production port and launch settings;
# `show_options` controls whether the "Advanced options" accordion is shown.
prod = False
show_options = False
port = 8081 if prod else 8080
|
|
|
|
|
import os |
|
import gc |
|
import random |
|
import time |
|
import gradio as gr |
|
import numpy as np |
|
import imageio |
|
from huggingface_hub import HfApi |
|
import torch |
|
import spaces |
|
from PIL import Image |
|
from diffusers import ( |
|
ControlNetModel, |
|
DPMSolverMultistepScheduler, |
|
StableDiffusionControlNetPipeline, |
|
|
|
) |
|
from controlnet_aux_local import NormalBaeDetector |
|
|
|
from diffusers.models.attention_processor import AttnProcessor2_0 |
|
# Largest value a signed 32-bit integer can hold; used as the upper seed bound.
MAX_SEED = np.iinfo(np.int32).max
# Token passed to the Hub upload calls; None when the variable is not set.
API_KEY = os.environ.get("API_KEY")

print("CUDA version:", torch.version.cuda)
print("loading everything")
compiled = False
# Shared Hugging Face Hub client.
api = HfApi()
|
|
|
class Preprocessor:
    """Lazily loaded, callable wrapper around a ControlNet annotator model.

    Only the "NormalBae" annotator is supported. The wrapped model is moved to
    CUDA when it is loaded.
    """

    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self):
        # Annotator instance; stays None until load() has succeeded.
        self.model = None
        # Name of the annotator currently held in self.model ("" = none).
        self.name = ""

    def load(self, name: str) -> None:
        """Load the annotator called *name* unless it is already loaded.

        Args:
            name: Annotator identifier. Only "NormalBae" is recognised.

        Raises:
            ValueError: If *name* is neither the currently loaded annotator
                nor a supported one.
        """
        if name == self.name:
            # Already loaded (or the initial "" state was requested): no-op.
            return
        if name != "NormalBae":
            raise ValueError(f"Unsupported preprocessor: {name!r}")

        print("Loading NormalBae")
        self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to("cuda")
        torch.cuda.empty_cache()
        self.name = name

    def __call__(self, image: "Image.Image", **kwargs) -> "Image.Image":
        """Run the loaded annotator on *image*, forwarding **kwargs** to it.

        Raises:
            RuntimeError: If no annotator has been loaded yet.
        """
        if self.model is None:
            raise RuntimeError("No preprocessor loaded; call load() first.")
        return self.model(image, **kwargs)
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Model setup. Runs once at import time and moves everything to CUDA.
# ---------------------------------------------------------------------------
model_id = "lllyasviel/control_v11p_sd15_normalbae"
print("initializing controlnet")
controlnet = ControlNetModel.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    # NOTE(review): `attn_implementation` looks like a transformers-style
    # argument — confirm diffusers' ControlNetModel honours it.
    attn_implementation="flash_attention_2",
).to("cuda")

# NOTE(review): `denoise_final`, `device_map` and `torch_dtype` do not look
# like DPMSolverMultistepScheduler config fields — confirm they are not
# silently ignored.
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    "ashllay/stable-diffusion-v1-5-archive",
    solver_order=2,
    subfolder="scheduler",
    use_karras_sigmas=True,
    final_sigmas_type="sigma_min",
    algorithm_type="sde-dpmsolver++",
    prediction_type="epsilon",
    thresholding=False,
    denoise_final=True,
    device_map="cuda",
    torch_dtype=torch.float16,
)

# Single-file checkpoint used as the base Stable Diffusion model.
base_model_url = "https://huggingface.co/broyang/hentaidigitalart_v20/blob/main/realcartoon3d_v15.safetensors"

print('loading pipe')
pipe = StableDiffusionControlNetPipeline.from_single_file(
    base_model_url,
    safety_checker=None,
    controlnet=controlnet,
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to("cuda")

print("loading preprocessor")
preprocessor = Preprocessor()
preprocessor.load("NormalBae")

# The pipeline was already moved to CUDA when it was created above, so no
# second `.to("cuda")` is needed here.
print("---------------Loaded controlnet pipeline---------------")
# Collect unreachable Python objects first so that the CUDA blocks they held
# can actually be released by empty_cache().
gc.collect()
torch.cuda.empty_cache()
# max_memory_allocated() reports the peak allocation so far, not the current one.
print(f"CUDA memory allocated: {torch.cuda.max_memory_allocated(device='cuda') / 1e9:.2f} GB")
# NOTE(review): nothing in this section compiles the model; the message below
# is kept unchanged for log compatibility.
print("Model Compiled!")
|
|
|
|
|
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a random seed in [0, MAX_SEED] if requested, else *seed* as given."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
|
|
|
def get_additional_prompt():
    """Return the base quality prompt plus one random top, bottom and accessory.

    The three random picks are drawn in that order, and the result always ends
    with the ``score_9`` tag.
    """
    base = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
    # Order matters: one item is drawn from each list, first to last.
    wardrobe = (
        ["tank top", "blouse", "button up shirt", "sweater", "corset top"],
        ["short skirt", "athletic shorts", "jean shorts", "pleated skirt", "short skirt", "leggings", "high-waisted shorts"],
        ["knee-high boots", "gloves", "Thigh-high stockings", "Garter belt", "choker", "necklace", "headband", "headphones"],
    )
    picks = [random.choice(options) for options in wardrobe]
    return ", ".join([base, *picks, "score_9"])
|
|
|
|
|
def get_prompt(prompt, additional_prompt):
    """Build the positive prompt that is sent to the diffusion pipeline.

    Args:
        prompt: The user's description. May be empty, whitespace-only or None.
        additional_prompt: Accepted for interface compatibility; not used.

    Returns:
        A generic quality prompt when no description was given, otherwise the
        quality prompt with the (stripped) description embedded in it.
    """
    # Treat None and whitespace-only input like an empty description so they
    # are never interpolated into the prompt text.
    description = (prompt or "").strip()

    if not description:
        # Only the effective result of the empty-prompt path is kept: earlier
        # code picked a random themed preset and then immediately replaced it
        # with this default.
        return "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed,tungsten white balance"

    return f"hyperrealistic photography of {description},extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
|
|
|
# Custom stylesheet handed to gr.Blocks: centres headings, hides the footer,
# caps the container width, and makes `.gr-image` elements a fixed-height box
# whose image is cropped to fill it.
css = """
h1, h2, h3 {
    text-align: center;
    display: block;
}
footer {
    visibility: hidden;
}
.gradio-container {
    max-width: 1100px !important;
}
.gr-image {
    display: flex;
    justify-content: center;
    align-items: center;
    width: 100%;
    height: 512px;
    overflow: hidden;
}
.gr-image img {
    width: 100%;
    height: 100%;
    object-fit: cover;
    object-position: center;
}
"""
|
# Gradio UI definition. Components are created in layout order, so the order
# of statements inside this block determines how the page is arranged.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy of the file that had lost its indentation — confirm it against
# the intended layout.
with gr.Blocks("bethecloud/storj_theme", css=css) as demo:

    # Tuning controls. The accordion is hidden and collapsed unless
    # `show_options` is True, so these normally just supply default values.
    with gr.Row():
        with gr.Accordion("Advanced options", open=show_options, visible=show_options):
            num_images = gr.Slider(
                label="Images", minimum=1, maximum=4, value=1, step=1
            )
            image_resolution = gr.Slider(
                label="Image resolution",
                minimum=256,
                maximum=1024,
                value=768,
                step=256,
            )
            preprocess_resolution = gr.Slider(
                label="Preprocess resolution",
                minimum=128,
                maximum=1024,
                value=768,
                step=1,
            )
            num_steps = gr.Slider(
                label="Number of steps", minimum=1, maximum=100, value=12, step=1
            )
            guidance_scale = gr.Slider(
                label="Guidance scale", minimum=0.1, maximum=30.0, value=5.5, step=0.1
            )
            # process_image() replaces the seed it receives with a random one,
            # so this slider currently has no effect on the output.
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            # Not part of `config` below, so its value is never sent to a handler.
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            a_prompt = gr.Textbox(
                label="Additional prompt",
                value = ""
            )
            n_prompt = gr.Textbox(
                label="Negative prompt",
                value="EasyNegativeV2, fcNeg, (badhandv4:1.4), chubby face, young kids, (worst quality, low quality, bad quality, normal quality:2.0), (bad hands, missing fingers, extra fingers:2.0)",
            )

    # Free-text description of the desired image (optional).
    with gr.Column():
        prompt = gr.Textbox(
            label="Description",
            placeholder="Enter a description (optional)",
        )

    # Input image, output image, and one initially hidden button for each.
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=300):
            image = gr.Image(
                label="Input",
                sources=["upload"],
                show_label=True,
                mirror_webcam=True,
                type="pil",
            )

        # Re-runs generation on the current input image.
        with gr.Column():
            run_button = gr.Button(value="Use this one", size="lg", visible=False)

        with gr.Column(scale=1, min_width=300):
            result = gr.Image(
                label="Output",
                interactive=False,
                type="pil",
                show_share_button= False,
            )

        # Feeds the current output back in as the next input.
        with gr.Column():
            use_ai_button = gr.Button(value="Use this one", size="lg", visible=False)
    # Ordered handler inputs; the order must match the positional parameters
    # of auto_process_image() and, after `previous_result`, of submit().
    config = [
        image,
        prompt,
        a_prompt,
        n_prompt,
        num_images,
        image_resolution,
        preprocess_resolution,
        num_steps,
        guidance_scale,
        seed,
    ]

    with gr.Row():
        helper_text = gr.Markdown("## Tap and hold (on mobile) to save the image.", visible=True)


    @gr.on(triggers=[image.upload, prompt.submit, run_button.click], inputs=config, outputs=result, show_progress="minimal")
    def auto_process_image(image, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
        """Generate an output for the current input image and settings.

        Fired on image upload, on prompt submit, and on the run button. It
        forwards every argument except `progress` to the module-level
        process_image(), which is looked up when the handler runs.
        """
        return process_image(image, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)

    @gr.on(triggers=[use_ai_button.click], inputs=[result] + config, outputs=[image, result], show_progress="minimal")
    def submit(previous_result, image, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
        """Use the previous output as the new input and generate again.

        This is a generator handler with two updates for (image, result).
        """
        # First update: show the previous output in the input slot right away
        # and leave the output component untouched.
        yield previous_result, gr.update()

        # The `image` argument is ignored; generation runs on the previous output.
        new_result = process_image(previous_result, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)

        # Second update: publish the newly generated image.
        yield previous_result, new_result


    @gr.on(triggers=[image.upload, use_ai_button.click, run_button.click], inputs=None, outputs=[run_button, use_ai_button], show_progress="hidden")
    def turn_buttons_off():
        """Hide both action buttons when a new generation is triggered."""
        return gr.update(visible=False), gr.update(visible=False)


    @gr.on(triggers=[result.change], inputs=None, outputs=[use_ai_button, run_button], show_progress="hidden")
    def turn_buttons_on():
        """Show both action buttons whenever the output image changes."""
        return gr.update(visible=True), gr.update(visible=True)
|
|
|
|
|
@spaces.GPU(duration=12)
@torch.inference_mode()
def process_image(
    image,
    prompt,
    a_prompt,
    n_prompt,
    num_images,
    image_resolution,
    preprocess_resolution,
    num_steps,
    guidance_scale,
    seed,
    progress=gr.Progress(track_tqdm=True)
):
    """Run the NormalBae ControlNet pipeline on *image* and return one result.

    Steps: compute the control image with the preprocessor, build the prompt
    via get_prompt(), run the diffusion pipeline, then save the input and the
    output locally and queue both files for upload to the Hub dataset.

    Args:
        image: Input image (the UI passes a PIL image).
        prompt: User description, forwarded to get_prompt().
        a_prompt: Additional prompt, forwarded to get_prompt().
        n_prompt: Negative prompt.
        num_images: Number of images the pipeline generates per prompt.
        image_resolution: Resolution passed to the preprocessor.
        preprocess_resolution: Detection resolution passed to the preprocessor.
        num_steps: Number of inference steps.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Ignored; a random seed is drawn on every call (see below).
        progress: Gradio progress tracker; not used directly in the body.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If *image* is None (nothing has been uploaded yet).
    """
    # Handlers such as prompt.submit can fire before any image exists; fail
    # with a readable message instead of crashing inside the annotator.
    if image is None:
        raise gr.Error("Please upload an image first.")

    preprocess_start = time.time()
    print("processing image")
    # NOTE(review): the incoming `seed` is overwritten here, so the value from
    # the UI is never used — confirm that this is intended.
    seed = random.randint(0, MAX_SEED)
    # torch.cuda.manual_seed() returns None, so it cannot serve as the
    # pipeline's generator; build an explicit seeded CUDA generator instead.
    generator = torch.Generator(device="cuda").manual_seed(seed)
    preprocessor.load("NormalBae")
    control_image = preprocessor(
        image=image,
        image_resolution=image_resolution,
        detect_resolution=preprocess_resolution,
    )
    preprocess_time = time.time() - preprocess_start

    custom_prompt = str(get_prompt(prompt, a_prompt))
    negative_prompt = str(n_prompt)
    print(f"{custom_prompt}")
    print(f"\n-------------------------Preprocess done in: {preprocess_time:.2f} seconds-------------------------")
    start = time.time()
    # Only the first image is kept, even when num_images > 1.
    results = pipe(
        prompt=custom_prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images,
        num_inference_steps=num_steps,
        generator=generator,
        image=control_image,
    ).images[0]
    print(f"\n-------------------------Inference done in: {time.time() - start:.2f} seconds-------------------------")
    torch.cuda.empty_cache()

    # Persist input and output under a shared timestamp, then queue both
    # uploads in the background (run_as_future=True).
    # NOTE(review): the local copies are not removed afterwards, and requests
    # landing in the same second share file names.
    timestamp = int(time.time())
    img_path = f"{timestamp}.jpg"
    results_path = f"{timestamp}_out.jpg"
    imageio.imsave(img_path, image)
    imageio.imsave(results_path, results)
    for path in (img_path, results_path):
        api.upload_file(
            path_or_fileobj=path,
            path_in_repo=path,
            repo_id="broyang/anime-ai-outputs2",
            repo_type="dataset",
            token=API_KEY,
            run_as_future=True,
        )
    return results
|
|
|
# Start the app. Outside production the queue's API is closed and the API
# page is hidden; in production the queue is capped at 20 and the server
# binds to localhost on the configured port.
if not prod:
    demo.queue(api_open=False).launch(show_api=False)
else:
    demo.queue(max_size=20).launch(server_name="localhost", server_port=port)