import gradio as gr
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
from diffusers.utils import load_image
from transformers import DPTImageProcessor, DPTForDepthEstimation
import torch
import sa_handler
import pipeline_calls
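# sa_handler and pipeline_calls are local helper modules bundled with this Space (StyleAligned helpers)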
# init models
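# DPT depth estimator + image processor: turn the uploaded image into a depth map for ControlNet conditioning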
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
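# depth ControlNet for SDXL, loaded in fp16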
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-depth-sdxl-1.0",
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
).to("cuda")
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
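# SDXL base pipeline wired to the depth ControlNet and the fixed VAE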
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_model_cpu_offload()
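# StyleAligned: share self-attention across the batch (with AdaIN on queries and keys) so every output follows the reference style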
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=False,
    share_layer_norm=False,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
)
handler = sa_handler.Handler(pipeline)
handler.register(sa_args)
# get depth maps
def get_depth_maps(image):
    image = load_image(image)
    depth_map = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)
    return depth_map
# run ControlNet depth with StyleAligned
def style_aligned_controlnet(reference_prompt, target_prompt, image):
    controlnet_conditioning_scale = 0.8
    num_images_per_prompt = 1  # adjust according to VRAM size
    depth_map = get_depth_maps(image)
    # one latent per output image: index 0 for the reference prompt, the rest for the target prompt
    latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)
    latents[1:] = torch.randn(num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)
    images = pipeline_calls.controlnet_call(pipeline, [reference_prompt, target_prompt],
                                            image=depth_map,
                                            num_inference_steps=50,
                                            controlnet_conditioning_scale=controlnet_conditioning_scale,
                                            num_images_per_prompt=num_images_per_prompt,
                                            latents=latents)
    return images[0]
# run StyleAligned
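# example prompt set kept for reference; the Gradio demo below does not use it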
sets_of_prompts = [
"a toy train. macro photo. 3d game asset",
"a toy airplane. macro photo. 3d game asset",
"a toy bicycle. macro photo. 3d game asset",
"a toy car. macro photo. 3d game asset",
"a toy boat. macro photo. 3d game asset",
]
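# Gradio UI: reference prompt + image on the left, generation prompt + result on the right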
with gr.Blocks() as demo:
    with gr.Row(variant='panel'):
        with gr.Group():
            gr.Markdown("### <center>Reference Prompt and Image</center>")
            ref_prompt = gr.Textbox(label="Enter a Prompt describing the reference image", placeholder='a photo of <object> in <style name> style')
            depth_map = gr.Image(label="Upload the image to get Depth Map")
        with gr.Group():
            gr.Markdown("### <center>Prompt for generation and generated Image</center>")
            prompt = gr.Textbox(label="Enter a Prompt", placeholder='a photo of <object> in <style name> style')
            output = gr.Image(label="Style-Aligned ControlNet", type='pil')
    btn = gr.Button("Generate", size='sm')
    btn.click(fn=style_aligned_controlnet, inputs=[ref_prompt, prompt, depth_map], outputs=output, api_name="style_aligned_controlnet")
demo.launch()