MoRanYue commited on
Commit
1feb9c4
1 Parent(s): f3949fd
KBlueLeaf-Kohaku-XL-Epsilon-rev3.code-workspace ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "folders": [
3
+ {
4
+ "path": "."
5
+ }
6
+ ],
7
+ "settings": {}
8
+ }
app.py CHANGED
@@ -1,3 +1,66 @@
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- gr.load("models/KBlueLeaf/Kohaku-XL-Epsilon-rev3").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from diffusers import DPMSolverMultistepScheduler, AutoencoderKL, UNet2DConditionModel
3
+ from transformers import CLIPTextModel, CLIPTokenizer
4
+ import torch
5
+ from tqdm.auto import tqdm
6
+ from time import time
7
+ from PIL import Image
8
 
9
+ vae = AutoencoderKL.from_pretrained("KBlueLeaf/Kohaku-XL-Epsilon-rev3", subfolder="vae")
10
+ tokenizer = CLIPTokenizer.from_pretrained("KBlueLeaf/Kohaku-XL-Epsilon-rev3", subfolder="tokenizer")
11
+ textEncoder = CLIPTextModel.from_pretrained("KBlueLeaf/Kohaku-XL-Epsilon-rev3", subfolder="text_encoder")
12
+ unet = UNet2DConditionModel.from_pretrained("KBlueLeaf/Kohaku-XL-Epsilon-rev3", subfolder="unet")
13
+ scheduler = DPMSolverMultistepScheduler.from_pretrained("KBlueLeaf/Kohaku-XL-Epsilon-rev3", subfolder="scheduler")
14
+
15
+ torchDevice = "cuda"
16
+ vae.to(torchDevice)
17
+ textEncoder.to(torchDevice)
18
+ unet.to(torchDevice)
19
+
20
+ def generate(prompt: str, negativePrompt: str, steps: int, cfg: float, seed: int, randomized: bool, width: int, height: int):
21
+ generator = torch.manual_seed(time())
22
+ if randomized:
23
+ seed = torch.randint(10000, 9223372036854776000, (1,))[0]
24
+ batchSize = len(prompt)
25
+ textInput = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
26
+ with torch.no_grad():
27
+ textEmbeddings = textEncoder(textInput.input_ids.to(torchDevice), attention_mask=textInput.attention_mask.to(torchDevice))[0]
28
+ maxLength = textInput.input_ids.shape[-1]
29
+ unconditionedInput = tokenizer([""] * batchSize, padding="max_length", max_length=maxLength, return_tensors="pt")
30
+ unconditionedEmbeddings = textEncoder(unconditionedInput.input_ids.to(torchDevice))[0]
31
+ textEmbeddings = torch.cat([unconditionedEmbeddings, textEmbeddings])
32
+
33
+ latents = torch.randn((batchSize, unet.config.in_channels, height // 8, width // 8), generator=generator, device=torchDevice)
34
+ latents = latents * scheduler.init_noise_sigma
35
+
36
+ scheduler.set_timesteps(steps)
37
+ for t in tqdm(scheduler.timesteps):
38
+ latentModelInput = torch.cat([latents] * 2)
39
+ latentModelInput = scheduler.scale_model_input(latentModelInput, timestep=t)
40
+ with torch.no_grad():
41
+ noisePred = unet(latentModelInput, t, encoder_hidden_states=textEmbeddings).sample
42
+ unconditionedNoisePred, noisePredText = noisePred.chunk(2)
43
+ noisePred = unconditionedNoisePred + cfg * (noisePredText - unconditionedNoisePred)
44
+ latents = scheduler.step(noisePred, t, latents).prev_sample
45
+
46
+ latents = 1 / 0.18215 * latents
47
+ with torch.no_grad():
48
+ image = vae.decode(latents).sample
49
+ image = (image / 2 + 0.5).clamp(0, 1).squeeze()
50
+ image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
51
+ images = (image * 255).round().astype("uint8")
52
+ return Image.fromarray(images)
53
+
54
+ interface = gr.Interface(fn=generate, inputs=[
55
+ gr.Textbox(lines=3, placeholder="Prompt is here...", label="Prompt"),
56
+ gr.Textbox(lines=3, placeholder="Negative prompt is here...", label="Negative Prompt"),
57
+ gr.Slider(0, 1000, step=1, label="Steps", value=20),
58
+ gr.Slider(0, 50, step=0.1, label="CFG Scale", value=8),
59
+ gr.Number(label="Seed", value=0),
60
+ gr.Checkbox(label="Randomize Seed", value=True),
61
+ gr.Slider(256, 999999, step=64, label="Width", value=512),
62
+ gr.Slider(256, 999999, step=64, label="Height", value=512),
63
+ ], outputs="image")
64
+
65
+ if __name__ == "__main__":
66
+ interface.launch()