# Stick2Body / app.py
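"""Stick2Body: a Gradio app that redraws a stick-figure pose image as a full
character, using an SDXL ControlNet img2img pipeline (animagine-xl-3.1) with a
hand-fixing LoRA. Runs as a Hugging Face Space (note the @spaces.GPU decorator)."""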
import spaces
import gradio as gr
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
from PIL import Image
import os
import time
from utils.dl_utils import dl_cn_model, dl_cn_config, dl_lora_model
from utils.image_utils import resize_image_aspect_ratio, base_generation
from utils.prompt_utils import remove_duplicates
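
# Local directories that hold the downloaded ControlNet and LoRA weights.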
path = os.getcwd()
cn_dir = f"{path}/controlnet"
lora_dir = f"{path}/lora"
os.makedirs(cn_dir, exist_ok=True)
os.makedirs(lora_dir, exist_ok=True)
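
# Fetch the ControlNet weights, its config, and the LoRA weights at startup.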
dl_cn_model(cn_dir)
dl_cn_config(cn_dir)
dl_lora_model(lora_dir)


def load_model(lora_dir, cn_dir):
    """Build the SDXL ControlNet img2img pipeline with the fp16-fix VAE and the hand-fixing LoRA."""
    dtype = torch.float16
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
    controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        "cagliostrolab/animagine-xl-3.1", controlnet=controlnet, vae=vae, torch_dtype=dtype
    )
    # Offload model components to CPU between forward passes to save VRAM.
    pipe.enable_model_cpu_offload()
    pipe.load_lora_weights(lora_dir, weight_name="Fixhands_anime_bdsqlsz_V1.safetensors")
    return pipe


@spaces.GPU(duration=120)
def predict(input_image_path, prompt, negative_prompt, controlnet_scale):
    # (Re)build the pipeline inside the GPU-decorated call.
    pipe = load_model(lora_dir, cn_dir)
    input_image = Image.open(input_image_path)
    # Plain white canvas the same size as the input: at strength 1.0 the
    # img2img start image is fully replaced, so the ControlNet pose image
    # drives the composition.
    base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
    resize_image = resize_image_aspect_ratio(input_image)
    resize_base_image = resize_image_aspect_ratio(base_image)
    generator = torch.manual_seed(0)
    last_time = time.time()
    prompt = "masterpiece, best quality, simple background, white background, bald, nude, " + prompt
    prompt = remove_duplicates(prompt)
    print(prompt)
    output_image = pipe(
        image=resize_base_image,
        control_image=resize_image,
        strength=1.0,
        prompt=prompt,
        negative_prompt=negative_prompt,
        controlnet_conditioning_scale=float(controlnet_scale),
        generator=generator,
        num_inference_steps=30,
        eta=1.0,
    ).images[0]
    print(f"Time taken: {time.time() - last_time}")
    # Restore the original resolution before returning.
    output_image = output_image.resize(input_image.size, Image.LANCZOS)
    return output_image
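
# Example (hypothetical local test; the paths are placeholders):
#   result = predict("pose.png", "1girl, standing", "lowres, blurry", 1.0)
#   result.save("result.png")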


class Img2Img:
    def __init__(self):
        # Initialize attributes before layout() so the Gradio components
        # assigned there are not overwritten afterwards.
        self.tagger_model = None
        self.input_image_path = None
        self.canny_image = None
        self.demo = self.layout()

    def layout(self):
        css = """
        #intro{
            max-width: 32rem;
            text-align: center;
            margin: 0 auto;
        }
        """
        with gr.Blocks(css=css) as demo:
            with gr.Row():
                with gr.Column():
                    self.input_image_path = gr.Image(label="input_image", type="filepath")
                    self.prompt = gr.Textbox(label="prompt", lines=3)
                    self.negative_prompt = gr.Textbox(
                        label="negative_prompt",
                        lines=3,
                        value="nsfw, nipples, bad anatomy, liquid fingers, low quality, worst quality, out of focus, ugly, error, jpeg artifacts, lowres, blurry, bokeh",
                    )
                    self.controlnet_scale = gr.Slider(
                        minimum=0.5, maximum=2.0, value=1.0, step=0.01, label="Stick_fidelity"
                    )
                    generate_button = gr.Button("generate")
                with gr.Column():
                    self.output_image = gr.Image(type="pil", label="output_image")
            generate_button.click(
                fn=predict,
                inputs=[self.input_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
                outputs=self.output_image,
            )
        return demo

img2img = Img2Img()
img2img.demo.launch(share=True)