Upload 2 files
- app.py +175 -0
- sd_models.py +239 -0
app.py
ADDED
@@ -0,0 +1,175 @@
import gradio as gr

from rf_models import RF_model
from sd_models import SD_model

import torch
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import torch.nn.functional as F

from diffusers import StableDiffusionXLImg2ImgPipeline
import time
import copy
import numpy as np

# SDXL refiner pipeline used by the two "Refine ..." buttons below.
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipe = pipe.to("cuda")

# Shared state populated by the loader functions and read by the Gradio callbacks.
global model
global base_model
global img

def set_model(model_id):
    global model
    if model_id == "InstaFlow-0.9B":
        model = RF_model("./instaflow_09b.pt")
    elif model_id == "InstaFlow-1.7B":
        model = RF_model("./instaflow_17b.pt")
    else:
        raise NotImplementedError
    print('Finished Loading Model!')

def set_base_model(model_id):
    global base_model
    if model_id == "runwayml/stable-diffusion-v1-5":
        base_model = SD_model("runwayml/stable-diffusion-v1-5")
    else:
        raise NotImplementedError
    print('Finished Loading Base Model!')

def set_new_latent_and_generate_new_image(seed, prompt, negative_prompt="", num_inference_steps=1, guidance_scale=0.0):
    print('Generate with input seed')
    global model
    global img
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)
    print(seed, num_inference_steps, guidance_scale)

    t_s = time.time()
    new_image = model.set_new_latent_and_generate_new_image(int(seed), prompt, negative_prompt, int(num_inference_steps), guidance_scale)
    inf_time = time.time() - t_s

    img = copy.copy(new_image[0])

    return new_image[0], inf_time

def set_new_latent_and_generate_new_image_with_base_model(seed, prompt, num_inference_steps=1, guidance_scale=0.0):
    print('Generate with input seed')
    global base_model
    global img
    negative_prompt = ""
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)
    print(seed, num_inference_steps, guidance_scale)

    t_s = time.time()
    new_image = base_model.set_new_latent_and_generate_new_image(int(seed), prompt, negative_prompt, int(num_inference_steps), guidance_scale)
    inf_time = time.time() - t_s

    img = copy.copy(new_image[0])

    return new_image[0], inf_time


def set_new_latent_and_generate_new_image_and_random_seed(seed, prompt, negative_prompt="", num_inference_steps=1, guidance_scale=0.0):
    print('Generate with a random seed')
    global model
    global img
    seed = np.random.randint(0, 2**32)
    num_inference_steps = int(num_inference_steps)
    guidance_scale = float(guidance_scale)
    print(seed, num_inference_steps, guidance_scale)

    t_s = time.time()
    new_image = model.set_new_latent_and_generate_new_image(int(seed), prompt, negative_prompt, int(num_inference_steps), guidance_scale)
    inf_time = time.time() - t_s

    img = copy.copy(new_image[0])

    return new_image[0], seed, inf_time


def refine_image_512(prompt):
    print('Refine with SDXL-Refiner (512)')
    global img

    t_s = time.time()
    # Round-trip through a CHW tensor and back to an HWC float array (kept symmetric with the 1024 path below).
    img = torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2)
    img = img.permute(0, 2, 3, 1).squeeze(0).cpu().numpy()
    new_image = pipe(prompt, image=img).images[0]
    print('time consumption:', time.time() - t_s)
    new_image = np.array(new_image) * 1.0 / 255.

    img = new_image

    return new_image

def refine_image_1024(prompt):
    print('Refine with SDXL-Refiner (1024)')
    global img

    t_s = time.time()
    # Upsample the cached generation to 1024x1024 before refining.
    img = torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2)
    img = torch.nn.functional.interpolate(img, size=1024, mode='bilinear')
    img = img.permute(0, 2, 3, 1).squeeze(0).cpu().numpy()
    new_image = pipe(prompt, image=img).images[0]
    print('time consumption:', time.time() - t_s)
    new_image = np.array(new_image) * 1.0 / 255.

    img = new_image

    return new_image

# Load the one-step model and the multi-step baseline once at startup.
set_model('InstaFlow-0.9B')
set_base_model("runwayml/stable-diffusion-v1-5")

with gr.Blocks() as gradio_gui:
    gr.Markdown("Set Input Seed and Text Prompts Here")
    with gr.Row():
        with gr.Column(scale=0.4):
            seed_input = gr.Textbox(value='101098274', label="Random Seed")
        with gr.Column(scale=0.4):
            prompt_input = gr.Textbox(value='A high-resolution photograph of a waterfall in autumn; muted tone', label="Prompt")

    with gr.Row():
        with gr.Column(scale=0.4):
            with gr.Group():
                gr.Markdown("Generation from InstaFlow-0.9B")
                im = gr.Image()

                gr.Markdown("Model ID: One-Step InstaFlow-0.9B")
                inference_time_output = gr.Textbox(value='0.0', label='Inference Time with One-Step Model (Second)')
                new_image_button = gr.Button(value="One-Step Generation with the Input Seed")
                new_image_button.click(set_new_latent_and_generate_new_image, inputs=[seed_input, prompt_input], outputs=[im, inference_time_output])

                next_image_button = gr.Button(value="One-Step Generation with a New Random Seed")
                next_image_button.click(set_new_latent_and_generate_new_image_and_random_seed, inputs=[seed_input, prompt_input], outputs=[im, seed_input, inference_time_output])

                refine_button_512 = gr.Button(value="Refine One-Step Generation with SDXL Refiner (Resolution: 512)")
                refine_button_512.click(refine_image_512, inputs=[prompt_input], outputs=[im])

                refine_button_1024 = gr.Button(value="Refine One-Step Generation with SDXL Refiner (Resolution: 1024)")
                refine_button_1024.click(refine_image_1024, inputs=[prompt_input], outputs=[im])

        with gr.Column(scale=0.4):
            with gr.Group():
                gr.Markdown("Generation from Stable Diffusion 1.5")
                im_base = gr.Image()

                gr.Markdown("Model ID: Multi-Step Stable Diffusion 1.5")
                base_model_inference_time_output = gr.Textbox(value='0.0', label='Inference Time with Multi-Step Stable Diffusion (Second)')

                num_inference_steps = gr.Textbox(value='25', label="Number of Inference Steps for Stable Diffusion")
                guidance_scale = gr.Textbox(value='5.0', label="Guidance Scale for Stable Diffusion")

                base_new_image_button = gr.Button(value="Multi-Step Generation with Stable Diffusion and the Input Seed")
                base_new_image_button.click(set_new_latent_and_generate_new_image_with_base_model, inputs=[seed_input, prompt_input, num_inference_steps, guidance_scale], outputs=[im_base, base_model_inference_time_output])

gradio_gui.launch()
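For reference, a minimal sketch of how the callbacks above can be exercised without the UI, e.g. in the same interpreter session after the module-level set_model/set_base_model calls have run. It assumes the rf_models module and the ./instaflow_09b.pt checkpoint referenced by set_model are available on the Space (neither is part of this two-file upload), plus a CUDA device:

# Hypothetical smoke test for the callbacks defined above.
image, inf_time = set_new_latent_and_generate_new_image(
    101098274,
    'A high-resolution photograph of a waterfall in autumn; muted tone',
)
print('one-step generation took %.3f s' % inf_time)

# refine_image_1024 reads the cached global `img` from the last generation
# and refines it with the SDXL refiner pipeline loaded at startup.
refined = refine_image_1024('A high-resolution photograph of a waterfall in autumn; muted tone')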
sd_models.py
ADDED
@@ -0,0 +1,239 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import math
import os
import random
from pathlib import Path
from typing import Optional, Union, List, Callable

import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from datasets import load_dataset
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate
from diffusers.utils.import_utils import is_xformers_available

import time

from torch.distributions import Normal, Categorical
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.mixture_same_family import MixtureSameFamily

from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import torchvision

import cv2


def inference_latent(
    pipeline,
    prompt: Union[str, List[str]],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
):
    # 0. Default height and width to unet
    height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
    width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    # pipeline.check_inputs(prompt, height, width, callback_steps)

    # 2. Define call parameters
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    device = pipeline._execution_device
    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_embeddings = pipeline._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # 4. Prepare timesteps
    pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipeline.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = pipeline.unet.in_channels
    latents = latents.reshape(1, num_channels_latents, 64, 64)

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order

    latents_cllt = [latents.detach().clone()]
    with torch.no_grad():
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = pipeline.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            outputs = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)

            latents = outputs.prev_sample

    example = {
        'latent': latents.detach().clone(),
        'text_embeddings': text_embeddings.chunk(2)[1].detach() if do_classifier_free_guidance else text_embeddings.detach(),
    }
    return example


def setup_seed(seed):
    import random
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.empty_cache()


class SD_model():
    # Thin wrapper around StableDiffusionPipeline; used by app.py as the multi-step baseline.

    def __init__(self, pretrained_model_name_or_path):
        self.pretrained_model_name_or_path = pretrained_model_name_or_path

        # Load scheduler, tokenizer and models.
        noise_scheduler = DDPMScheduler.from_pretrained(self.pretrained_model_name_or_path, subfolder="scheduler")
        tokenizer = CLIPTokenizer.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="tokenizer"
        )
        text_encoder = CLIPTextModel.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="text_encoder"
        )
        vae = AutoencoderKL.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="vae"
        )
        unet = UNet2DConditionModel.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="unet"
        )

        unet.eval()
        vae.eval()
        text_encoder.eval()

        # Freeze vae, text_encoder and unet
        vae.requires_grad_(False)
        text_encoder.requires_grad_(False)
        unet.requires_grad_(False)

        # These modules are only used for inference, so half precision is sufficient.
        weight_dtype = torch.float16
        self.weight_dtype = weight_dtype
        device = 'cuda'
        self.device = device

        # Move text_encoder, vae and unet to GPU and cast to weight_dtype
        text_encoder.to(device, dtype=weight_dtype)
        vae.to(device, dtype=weight_dtype)
        unet.to(device, dtype=weight_dtype)

        # Create the pipeline from the loaded modules.
        pipeline = StableDiffusionPipeline.from_pretrained(
            self.pretrained_model_name_or_path,
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            torch_dtype=weight_dtype,
        )
        pipeline = pipeline.to(device)

        from diffusers import DPMSolverMultistepScheduler
        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
        self.pipeline = pipeline

    def set_new_latent_and_generate_new_image(self, seed=None, prompt=None, negative_prompt="", num_inference_steps=25, guidance_scale=5.0):
        if seed is None:
            raise ValueError("Must have a pre-defined random seed")

        if prompt is None:
            raise ValueError("Must have a user-specified text prompt")

        setup_seed(seed)
        self.latents = torch.randn((1, 4*64*64), device=self.device).to(dtype=self.weight_dtype)
        self.prompt = prompt
        self.negative_prompt = negative_prompt
        self.guidance_scale = guidance_scale
        self.num_inference_steps = num_inference_steps

        prompts = [prompt]
        negative_prompts = [negative_prompt]

        output = inference_latent(
            self.pipeline,
            prompt=prompts,
            negative_prompt=negative_prompts,
            num_inference_steps=num_inference_steps,
            guidance_scale=self.guidance_scale,
            latents=self.latents.detach().clone(),
        )

        image = self.pipeline.decode_latents(output['latent'])

        self.org_image = image

        return image