Spaces:
Sleeping
Sleeping
import os
import re
import threading

import gradio as gr
import torch
from diffusers import DDIMScheduler, DiffusionPipeline
from masactrl.diffuser_utils import MasaCtrlPipeline
from masactrl.masactrl import MutualSelfAttentionControl
from masactrl.masactrl_utils import AttentionBase, regiter_attention_editor_diffusers
from pytorch_lightning import seed_everything
# Module-level setup: choose the compute device and load the diffusion
# pipeline once, at import time, so every request reuses the same model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DDIM scheduler configured with a scaled-linear beta schedule.
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

# Load the pipeline with the scheduler above and move it to the chosen device.
model = DiffusionPipeline.from_pretrained(
    "svjack/GenshinImpact_XL_Base",
    scheduler=scheduler,
).to(device)
def pathify(s, max_len=200):
    """Turn free text into a filesystem-safe directory name.

    The text is lower-cased and every character outside ``[a-z0-9]`` is
    replaced by ``_`` (so non-ASCII prompts collapse to underscores).

    Args:
        s: Text to convert, e.g. a prompt.
        max_len: Maximum length of the result. Keeps long prompts below the
            usual 255-byte file-name limit, which would otherwise make
            directory creation fail.

    Returns:
        The slug, truncated to ``max_len`` characters, or ``"untitled"`` when
        the slug would be empty (an empty name would make callers write
        directly into the parent directory).
    """
    slug = re.sub(r"[^a-zA-Z0-9]", "_", s.lower())
    return slug[:max_len] or "untitled"
# Guards the shared module-level ``model``: each request registers an attention
# editor on it and then runs inference, so overlapping requests could otherwise
# run with another request's editor installed.
_MODEL_LOCK = threading.Lock()


def _next_sample_dir(parent):
    """Create and return a fresh ``sample_<i>`` directory inside *parent*.

    Numbering starts at the number of entries already in *parent* and skips
    any index whose directory already exists, so earlier results are never
    overwritten (e.g. after a sample directory was deleted).
    """
    index = len(os.listdir(parent))
    while True:
        candidate = os.path.join(parent, f"sample_{index}")
        try:
            # os.mkdir raises if the directory exists, which makes claiming an
            # index atomic rather than a check-then-create race.
            os.mkdir(candidate)
            return candidate
        except FileExistsError:
            index += 1


def consistent_synthesis(prompt1, prompt2, guidance_scale, seed, starting_step, starting_layer):
    """Render both prompts without and with MasaCtrl from one shared latent.

    Args:
        prompt1: Source prompt.
        prompt2: Target prompt; its slug names the output directory.
        guidance_scale: Passed to the pipeline as ``guidance_scale``.
        seed: Seed for ``seed_everything``. The value actually applied is
            recorded in ``prompts.txt``.
        starting_step: Passed to ``MutualSelfAttentionControl``.
        starting_layer: Passed to ``MutualSelfAttentionControl``.

    Returns:
        ``[source image, prompt-2 image without MasaCtrl, prompt-2 image with
        MasaCtrl]``, taken from the pipeline's ``.images`` output.

    Side effects:
        Writes the three images and a ``prompts.txt`` record into a new
        ``masactrl_exp/<slug of prompt2>/sample_<i>/`` directory.
    """
    # Gradio sliders may deliver floats; seeds and step/layer indices are integers.
    seed = int(seed)
    starting_step = int(starting_step)
    starting_layer = int(starting_layer)
    prompts = [prompt1, prompt2]

    # Output directory grouped by target prompt.
    out_dir_ori = os.path.join("masactrl_exp", pathify(prompt2))
    os.makedirs(out_dir_ori, exist_ok=True)

    with _MODEL_LOCK:
        # seed_everything returns the seed it actually applied; an out-of-range
        # value such as the slider's -1 is replaced, so record the returned one.
        seed = seed_everything(seed)

        # One initial latent shared by both prompts, so the outputs differ only
        # through the prompt and the attention control. The [1, 4, 128, 128]
        # shape is presumably the SDXL latent for 1024x1024 output; TODO confirm.
        start_code = torch.randn([1, 4, 128, 128], device=device)
        start_code = start_code.expand(len(prompts), -1, -1, -1)

        # Baseline pass with the plain attention editor (no MasaCtrl).
        regiter_attention_editor_diffusers(model, AttentionBase())
        image_ori = model(prompts, latents=start_code, guidance_scale=guidance_scale).images

        # Hook MasaCtrl's mutual self-attention into the model and run again.
        editor = MutualSelfAttentionControl(starting_step, starting_layer, model_type="SDXL")
        regiter_attention_editor_diffusers(model, editor)
        image_masactrl = model(prompts, latents=start_code, guidance_scale=guidance_scale).images

    source_image, plain_image, masactrl_image = image_ori[0], image_ori[1], image_masactrl[-1]

    out_dir = _next_sample_dir(out_dir_ori)
    tag = f"step{starting_step}_layer{starting_layer}"
    source_image.save(os.path.join(out_dir, f"source_{tag}.png"))
    plain_image.save(os.path.join(out_dir, f"without_{tag}.png"))
    masactrl_image.save(os.path.join(out_dir, f"masactrl_{tag}.png"))

    # Explicit UTF-8 so non-ASCII prompts are written regardless of the locale.
    with open(os.path.join(out_dir, "prompts.txt"), "w", encoding="utf-8") as f:
        f.writelines(p + "\n" for p in prompts)
        f.write(f"seed: {seed}\n")
        f.write(f"starting_step: {starting_step}\n")
        f.write(f"starting_layer: {starting_layer}\n")

    print("Synthesized images are saved in", out_dir)
    return [source_image, plain_image, masactrl_image]
def create_demo_synthesis():
    """Build the Gradio Blocks UI that drives ``consistent_synthesis``.

    Layout:
        - a title;
        - an input row: prompt boxes, step/layer sliders and a Run button on
          the left; guidance-scale and seed sliders on the right;
        - a row with the three output images;
        - clickable example rows.

    Returns:
        The constructed ``gr.Blocks`` app (not launched).
    """
    # Each example row: (prompt 1, prompt 2, seed, starting step, starting layer).
    example_rows = [
        ["solo,ZHONGLI(genshin impact),1boy,highres,", "solo,ZHONGLI drink tea use chinese cup (genshin impact),1boy,highres,", 42, 4, 64],
        ["solo,KAMISATO AYATO(genshin impact),1boy,highres,", "solo,KAMISATO AYATO smiling (genshin impact),1boy,highres,", 42, 4, 55],
    ]

    with gr.Blocks() as ui:
        # Page title.
        gr.Markdown("# **Genshin Impact XL MasaCtrl Image Synthesis**")
        gr.Markdown("## **Input Settings**")
        with gr.Row():
            with gr.Column():
                source_box = gr.Textbox(
                    label="Prompt 1",
                    value="solo,ZHONGLI(genshin impact),1boy,highres,",
                )
                target_box = gr.Textbox(
                    label="Prompt 2",
                    value="solo,ZHONGLI drink tea use chinese cup (genshin impact),1boy,highres,",
                )
                with gr.Row():
                    step_slider = gr.Slider(label="Starting Step", minimum=0, maximum=999, value=4, step=1)
                    layer_slider = gr.Slider(label="Starting Layer", minimum=0, maximum=999, value=64, step=1)
                run_button = gr.Button("Run")
            with gr.Column():
                cfg_slider = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=42, step=1)

        gr.Markdown("## **Output**")
        with gr.Row():
            # Created left to right: source, without MasaCtrl, with MasaCtrl.
            result_images = [
                gr.Image(label="Source Image"),
                gr.Image(label="Image without MasaCtrl"),
                gr.Image(label="Image with MasaCtrl"),
            ]

        # Input order must match consistent_synthesis's positional parameters.
        run_button.click(
            fn=consistent_synthesis,
            inputs=[source_box, target_box, cfg_slider, seed_slider, step_slider, layer_slider],
            outputs=result_images,
        )
        gr.Examples(
            examples=example_rows,
            inputs=[source_box, target_box, seed_slider, step_slider, layer_slider],
        )
    return ui
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it.
    demo_synthesis = create_demo_synthesis()
    # share=True asks Gradio to also create a public share link.
    demo_synthesis.launch(
        share=True,
    )