Create worker_runpod.py
worker_runpod.py (ADDED, +234 -0)
import os, json, requests, runpod

import math
import random
import traceback

import fairscale.nn.model_parallel.initialize as fs_init
import gradio as gr
import numpy as np
from safetensors.torch import load_file
import torch
import torch.distributed as dist
from torchvision.transforms.functional import to_pil_image

import models
from transport import Sampler, create_transport
from diffusers.models import AutoencoderKL
from transformers import AutoModel, AutoTokenizer

discord_token = os.getenv('com_camenduru_discord_token')
web_uri = os.getenv('com_camenduru_web_uri')
web_token = os.getenv('com_camenduru_web_token')

with torch.inference_mode():
    # transport / ODE-sampler configuration
    path_type = "Linear"  # ["Linear", "GVP", "VP"]
    prediction = "velocity"  # ["velocity", "score", "noise"]
    loss_weight = None  # [None, "velocity", "likelihood"]
    sample_eps = None
    train_eps = None
    atol = 1e-6
    rtol = 1e-3
    reverse = None
    likelihood = None
    rank = 0
    num_gpus = 1
    ckpt = "/content/Lumina-T2X/models"
    ema = True
    dtype = torch.bfloat16  # ["bf16", "fp32"]

    # single-process distributed setup, required by the model-parallel layers
    os.environ["MASTER_PORT"] = str(8080)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(num_gpus)

    dist.init_process_group("nccl")
    fs_init.initialize_model_parallel(1)
    torch.cuda.set_device(rank)

    # load training args, text encoder, tokenizer, VAE, and the diffusion transformer
    train_args = torch.load(os.path.join(ckpt, "model_args.pth"))
    text_encoder = AutoModel.from_pretrained("4bit/gemma-2b", torch_dtype=dtype, device_map="cuda").eval()
    cap_feat_dim = text_encoder.config.hidden_size
    tokenizer = AutoTokenizer.from_pretrained("4bit/gemma-2b")
    tokenizer.padding_side = "right"

    vae = AutoencoderKL.from_pretrained((f"stabilityai/sd-vae-ft-{train_args.vae}" if train_args.vae != "sdxl" else "stabilityai/sdxl-vae"), torch_dtype=torch.float32).cuda()
    model = models.__dict__[train_args.model](
        qk_norm=train_args.qk_norm,
        cap_feat_dim=cap_feat_dim,
    )
    model.eval().to("cuda", dtype=dtype)
    state_dict = load_file(os.path.join(ckpt, f"consolidated{'_ema' if ema else ''}.{rank:02d}-of-{num_gpus:02d}.safetensors"), device="cpu")
    model.load_state_dict(state_dict, strict=True)

# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        text_inputs = tokenizer(
            captions,
            padding=True,
            pad_to_multiple_of=8,
            max_length=256,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        prompt_masks = text_inputs.attention_mask

        # the penultimate hidden state serves as the caption embedding
        prompt_embeds = text_encoder(
            input_ids=text_input_ids.cuda(),
            attention_mask=prompt_masks.cuda(),
            output_hidden_states=True,
        ).hidden_states[-2]

    return prompt_embeds, prompt_masks

@torch.inference_mode()
def generate(input):
    values = input["input"]

    cap1 = values['cap1']
    cap2 = values['cap2']
    cap3 = values['cap3']
    cap4 = values['cap4']
    neg_cap = values['neg_cap']
    resolution = values['resolution']  # ["2048x1024 (4x1 Grids)", "2560x1024 (4x1 Grids)", "3072x1024 (4x1 Grids)", "1024x1024 (2x2 Grids)", "1536x1536 (2x2 Grids)", "2048x2048 (2x2 Grids)", "1024x2048 (1x4 Grids)", "1024x2560 (1x4 Grids)", "1024x3072 (1x4 Grids)"]
    num_sampling_steps = values['num_sampling_steps']
    cfg_scale = values['cfg_scale']
    solver = values['solver']  # ["euler", "midpoint", "rk4"]
    t_shift = values['t_shift']
    seed = values['seed']
    scaling_method = values['scaling_method']  # ["Time-aware", "None"]
    scaling_watershed = values['scaling_watershed']
    proportional_attn = values['proportional_attn']

    with torch.autocast("cuda", dtype):
        try:
            # begin sampler
            transport = create_transport(
                path_type,
                prediction,
                loss_weight,
                train_eps,
                sample_eps,
            )
            sampler = Sampler(transport)
            sample_fn = sampler.sample_ode(
                sampling_method=solver,
                num_steps=num_sampling_steps,
                atol=atol,
                rtol=rtol,
                reverse=reverse,
                time_shifting_factor=t_shift,
            )
            # end sampler

            # parse "WxH (WSplitxHSplit Grids)" into image size and grid layout
            do_extrapolation = "Extrapolation" in resolution
            split = resolution.split(" ")[1].replace("(", "")
            w_split, h_split = split.split("x")
            resolution = resolution.split(" ")[0]
            w, h = resolution.split("x")
            w, h = int(w), int(h)
            latent_w, latent_h = w // 8, h // 8
            if int(seed) != 0:
                torch.random.manual_seed(int(seed))
            z = torch.randn([1, 4, latent_h, latent_w], device="cuda").to(dtype)
            z = z.repeat(2, 1, 1, 1)  # duplicated for classifier-free guidance

            cap_list = [cap1, cap2, cap3, cap4]
            global_cap = " ".join(cap_list)
            with torch.no_grad():
                if neg_cap != "":
                    cap_feats, cap_mask = encode_prompt(
                        cap_list + [neg_cap] + [global_cap], text_encoder, tokenizer, 0.0
                    )
                else:
                    cap_feats, cap_mask = encode_prompt(
                        cap_list + [""] + [global_cap], text_encoder, tokenizer, 0.0
                    )

            cap_mask = cap_mask.to(cap_feats.device)

            model_kwargs = dict(
                cap_feats=cap_feats[:-1],
                cap_mask=cap_mask[:-1],
                global_cap_feats=cap_feats[-1:],
                global_cap_mask=cap_mask[-1:],
                cfg_scale=cfg_scale,
                h_split_num=int(h_split),
                w_split_num=int(w_split),
            )
            if proportional_attn:
                model_kwargs["proportional_attn"] = True
                model_kwargs["base_seqlen"] = (train_args.image_size // 16) ** 2
            else:
                model_kwargs["proportional_attn"] = False
                model_kwargs["base_seqlen"] = None

            if do_extrapolation and scaling_method == "Time-aware":
                model_kwargs["scale_factor"] = math.sqrt(w * h / train_args.image_size**2)
                model_kwargs["scale_watershed"] = scaling_watershed
            else:
                model_kwargs["scale_factor"] = 1.0
                model_kwargs["scale_watershed"] = 1.0

            # run the ODE sampler and keep only the conditional sample
            samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
            samples = samples[:1]

            factor = 0.18215 if train_args.vae != "sdxl" else 0.13025
            samples = vae.decode(samples / factor).sample
            samples = (samples + 1.0) / 2.0
            samples.clamp_(0.0, 1.0)

            img = to_pil_image(samples[0].float())
            # write the image to disk so the upload step below can read it back;
            # the exact output path is an implementation choice, any writable location works
            result = f"/content/lumina_{seed}.png"
            img.save(result)

        except Exception:
            print(traceback.format_exc())
            return {"result": "ERROR"}

    response = None
    try:
        source_id = values['source_id']
        del values['source_id']
        source_channel = values['source_channel']
        del values['source_channel']
        job_id = values['job_id']
        del values['job_id']
        default_filename = os.path.basename(result)
        with open(result, "rb") as file:
            files = {default_filename: file.read()}
        payload = {"content": f"{json.dumps(values)} <@{source_id}>"}
        response = requests.post(
            f"https://discord.com/api/v9/channels/{source_channel}/messages",
            data=payload,
            headers={"authorization": f"Bot {discord_token}"},
            files=files,
        )
        response.raise_for_status()
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    finally:
        if os.path.exists(result):
            os.remove(result)

    if response and response.status_code == 200:
        image_url = response.json()['attachments'][0]['url']
        try:
            payload = {"jobId": job_id, "result": image_url}
            requests.post(f"{web_uri}/api/notify", data=json.dumps(payload), headers={'Content-Type': 'application/json', "authorization": f"{web_token}"})
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
        return {"result": image_url}
    else:
        return {"result": "ERROR"}

runpod.serverless.start({"handler": generate})
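For reference, a request to this handler carries its parameters in an `input` object whose keys match those read at the top of `generate`. The sketch below is illustrative only: the field names come from the code above, while every value (captions, sampler settings, Discord and job IDs) is a made-up placeholder rather than anything shipped in this commit.

# Hypothetical example request body for this worker; names mirror generate(),
# values are placeholders chosen for illustration.
payload = {
    "input": {
        "cap1": "top-left panel: a snowy mountain at dawn",
        "cap2": "top-right panel: a pine forest in fog",
        "cap3": "bottom-left panel: a frozen lake",
        "cap4": "bottom-right panel: a village under starlight",
        "neg_cap": "",
        "resolution": "1024x1024 (2x2 Grids)",
        "num_sampling_steps": 30,
        "cfg_scale": 4.0,
        "solver": "midpoint",
        "t_shift": 4,
        "seed": 42,
        "scaling_method": "Time-aware",
        "scaling_watershed": 0.3,
        "proportional_attn": True,
        # routing metadata consumed by the Discord upload and notify steps
        "source_id": "123456789012345678",
        "source_channel": "123456789012345678",
        "job_id": "job-0001",
    }
}

On a deployed endpoint, a body like this would typically be POSTed to the RunPod serverless `/run` or `/runsync` route with an API key; the handler then posts the rendered image to the given Discord channel and reports the attachment URL back to `{web_uri}/api/notify`.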