import gradio as gr
from diffusers import DPMSolverMultistepScheduler, AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
import torch
from tqdm.auto import tqdm
from PIL import Image

modelId = "KBlueLeaf/Kohaku-XL-Epsilon-rev3"

# Kohaku-XL-Epsilon is an SDXL checkpoint, so it ships two tokenizers and two
# text encoders, and both are needed below. from_pretrained takes no
# allow_pickle argument; it falls back to pickle (.bin) weights on its own
# when safetensors are not available.
vae = AutoencoderKL.from_pretrained(modelId, subfolder="vae")
tokenizer = CLIPTokenizer.from_pretrained(modelId, subfolder="tokenizer")
tokenizer2 = CLIPTokenizer.from_pretrained(modelId, subfolder="tokenizer_2")
textEncoder = CLIPTextModel.from_pretrained(modelId, subfolder="text_encoder")
textEncoder2 = CLIPTextModelWithProjection.from_pretrained(modelId, subfolder="text_encoder_2")
unet = UNet2DConditionModel.from_pretrained(modelId, subfolder="unet")
scheduler = DPMSolverMultistepScheduler.from_pretrained(modelId, subfolder="scheduler")

torchDevice = "cuda" if torch.cuda.is_available() else "cpu"
vae.to(torchDevice)
textEncoder.to(torchDevice)
textEncoder2.to(torchDevice)
unet.to(torchDevice)
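
# Minimal dual-encoder prompt helper (encodePrompt is a local convenience,
# not a diffusers API). The SDXL UNet expects the two encoders' per-token
# embeddings concatenated along the feature axis, plus the pooled embedding
# from the second (projection) encoder.
def encodePrompt(prompt: str):
    embeddings = []
    pooled = None
    for tok, encoder in ((tokenizer, textEncoder), (tokenizer2, textEncoder2)):
        textInput = tok(prompt, padding="max_length", max_length=tok.model_max_length,
                        truncation=True, return_tensors="pt")
        with torch.no_grad():
            output = encoder(textInput.input_ids.to(torchDevice), output_hidden_states=True)
        # SDXL conditions on the penultimate hidden layer of each encoder.
        embeddings.append(output.hidden_states[-2])
        pooled = output[0]  # only the value from the projection encoder is kept
    return torch.cat(embeddings, dim=-1), pooled
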
def generate(prompt: str, negativePrompt: str, steps: int, cfg: float, seed: int, randomized: bool, width: int, height: int):
    # Pick a fresh seed when requested, then seed a generator on the target
    # device so the latents are reproducible from the reported seed.
    if randomized:
        seed = int(torch.randint(0, 2**32, (1,)).item())
    generator = torch.Generator(device=torchDevice).manual_seed(int(seed))
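
    # Encode the prompt and the negative prompt; the negative embeddings feed
    # the unconditioned branch of classifier-free guidance.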
    textEmbeddings, pooledEmbeddings = encodePrompt(prompt)
    unconditionedEmbeddings, unconditionedPooled = encodePrompt(negativePrompt)
    textEmbeddings = torch.cat([unconditionedEmbeddings, textEmbeddings])
    addTextEmbeds = torch.cat([unconditionedPooled, pooledEmbeddings])
    # SDXL micro-conditioning: original size, crop top-left, target size
    # (assuming no cropping and original == target), once per guidance branch.
    addTimeIds = torch.tensor([[height, width, 0, 0, height, width]],
                              dtype=textEmbeddings.dtype, device=torchDevice).repeat(2, 1)
    latents = torch.randn((1, unet.config.in_channels, height // 8, width // 8),
                          generator=generator, device=torchDevice)
    latents = latents * scheduler.init_noise_sigma
    scheduler.set_timesteps(int(steps), device=torchDevice)
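
    # Classifier-free guidance: each step runs the UNet on a doubled batch so
    # one forward pass yields both the unconditioned and the conditioned
    # noise prediction.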
    for t in tqdm(scheduler.timesteps):
        latentModelInput = torch.cat([latents] * 2)
        latentModelInput = scheduler.scale_model_input(latentModelInput, timestep=t)
        with torch.no_grad():
            noisePred = unet(latentModelInput, t, encoder_hidden_states=textEmbeddings,
                             added_cond_kwargs={"text_embeds": addTextEmbeds, "time_ids": addTimeIds}).sample
        unconditionedNoisePred, noisePredText = noisePred.chunk(2)
        noisePred = unconditionedNoisePred + cfg * (noisePredText - unconditionedNoisePred)
        latents = scheduler.step(noisePred, t, latents).prev_sample
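
    # Undo the VAE's latent scaling factor before decoding; reading it from
    # the config (0.13025 for SDXL) avoids hard-coding the SD-1.x value.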
    latents = latents / vae.config.scaling_factor
    with torch.no_grad():
        image = vae.decode(latents).sample
    # Map the decoder output from [-1, 1] to 8-bit RGB exactly once.
    image = (image / 2 + 0.5).clamp(0, 1).squeeze()
    image = (image.permute(1, 2, 0) * 255).round().to(torch.uint8).cpu().numpy()
    return Image.fromarray(image)

interface = gr.Interface(fn=generate, inputs=[
    gr.Textbox(lines=3, placeholder="Prompt is here...", label="Prompt"),
    gr.Textbox(lines=3, placeholder="Negative prompt is here...", label="Negative Prompt"),
    gr.Slider(1, 1000, step=1, label="Steps", value=20),
    gr.Slider(0, 50, step=0.1, label="CFG Scale", value=8),
    gr.Number(label="Seed", value=0, precision=0),
    gr.Checkbox(label="Randomize Seed", value=True),
    # SDXL models are trained around 1024x1024; cap the sliders at a size a
    # single GPU can realistically handle.
    gr.Slider(256, 2048, step=64, label="Width", value=1024),
    gr.Slider(256, 2048, step=64, label="Height", value=1024),
], outputs="image")

if __name__ == "__main__":
    interface.launch()