camenduru commited on
Commit
20b1f60
·
verified ·
1 Parent(s): d4dda72

Create worker_runpod.py

Browse files
Files changed (1) hide show
  1. worker_runpod.py +234 -0
worker_runpod.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, json, requests, runpod
2
+
3
+ import math
4
+ import random
5
+ import traceback
6
+
7
+ import fairscale.nn.model_parallel.initialize as fs_init
8
+ import gradio as gr
9
+ import numpy as np
10
+ from safetensors.torch import load_file
11
+ import torch
12
+ import torch.distributed as dist
13
+ from torchvision.transforms.functional import to_pil_image
14
+
15
+ import models
16
+ from transport import Sampler, create_transport
17
+ from diffusers.models import AutoencoderKL
18
+ from transformers import AutoModel, AutoTokenizer
19
+
20
+ discord_token = os.getenv('com_camenduru_discord_token')
21
+ web_uri = os.getenv('com_camenduru_web_uri')
22
+ web_token = os.getenv('com_camenduru_web_token')
23
+
24
+ with torch.inference_mode():
25
+ path_type = "Linear" # ["Linear", "GVP", "VP"]
26
+ prediction = "velocity" # ["velocity", "score", "noise"]
27
+ loss_weight = None # [None, "velocity", "likelihood"]
28
+ sample_eps = None
29
+ train_eps = None
30
+ atol = 1e-6
31
+ rtol = 1e-3
32
+ reverse = None
33
+ likelihood = None
34
+ rank = 0
35
+ num_gpus = 1
36
+ ckpt = "/content/Lumina-T2X/models"
37
+ ema = True
38
+ dtype = torch.bfloat16 #["bf16", "fp32"]
39
+
40
+ os.environ["MASTER_PORT"] = str(8080)
41
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
42
+ os.environ["RANK"] = str(rank)
43
+ os.environ["WORLD_SIZE"] = str(num_gpus)
44
+
45
+ dist.init_process_group("nccl")
46
+ fs_init.initialize_model_parallel(1)
47
+ torch.cuda.set_device(rank)
48
+ train_args = torch.load(os.path.join(ckpt, "model_args.pth"))
49
+ text_encoder = AutoModel.from_pretrained("4bit/gemma-2b", torch_dtype=dtype, device_map="cuda").eval()
50
+ cap_feat_dim = text_encoder.config.hidden_size
51
+ tokenizer = AutoTokenizer.from_pretrained("4bit/gemma-2b")
52
+ tokenizer.padding_side = "right"
53
+
54
+ vae = AutoencoderKL.from_pretrained((f"stabilityai/sd-vae-ft-{train_args.vae}" if train_args.vae != "sdxl" else "stabilityai/sdxl-vae"), torch_dtype=torch.float32).cuda()
55
+ model = models.__dict__[train_args.model](
56
+ qk_norm=train_args.qk_norm,
57
+ cap_feat_dim=cap_feat_dim,
58
+ )
59
+ model.eval().to("cuda", dtype=dtype)
60
+ ckpt = load_file(os.path.join(ckpt, f"consolidated{'_ema' if ema else ''}.{rank:02d}-of-{num_gpus:02d}.safetensors"), device="cpu",)
61
+ model.load_state_dict(ckpt, strict=True)
62
+
63
+ # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
64
+ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
65
+ captions = []
66
+ for caption in prompt_batch:
67
+ if random.random() < proportion_empty_prompts:
68
+ captions.append("")
69
+ elif isinstance(caption, str):
70
+ captions.append(caption)
71
+ elif isinstance(caption, (list, np.ndarray)):
72
+ # take a random caption if there are multiple
73
+ captions.append(random.choice(caption) if is_train else caption[0])
74
+
75
+ with torch.no_grad():
76
+ text_inputs = tokenizer(
77
+ captions,
78
+ padding=True,
79
+ pad_to_multiple_of=8,
80
+ max_length=256,
81
+ truncation=True,
82
+ return_tensors="pt",
83
+ )
84
+
85
+ text_input_ids = text_inputs.input_ids
86
+ prompt_masks = text_inputs.attention_mask
87
+
88
+ prompt_embeds = text_encoder(
89
+ input_ids=text_input_ids.cuda(),
90
+ attention_mask=prompt_masks.cuda(),
91
+ output_hidden_states=True,
92
+ ).hidden_states[-2]
93
+
94
+ return prompt_embeds, prompt_masks
95
+
96
+ @torch.inference_mode()
97
+ def generate(input):
98
+ values = input["input"]
99
+
100
+ cap1 = values['cap1']
101
+ cap2 = values['cap2']
102
+ cap3 = values['cap3']
103
+ cap4 = values['cap4']
104
+ neg_cap = values['neg_cap']
105
+ resolution = values['resolution'] # ["2048x1024 (4x1 Grids)","2560x1024 (4x1 Grids)","3072x1024 (4x1 Grids)","1024x1024 (2x2 Grids)","1536x1536 (2x2 Grids)","2048x2048 (2x2 Grids)","1024x2048 (1x4 Grids)","1024x2560 (1x4 Grids)","1024x3072 (1x4 Grids)",]
106
+ num_sampling_steps = values['num_sampling_steps']
107
+ cfg_scale = values['cfg_scale']
108
+ solver = values['solver'] # ["euler", "midpoint", "rk4"]
109
+ t_shift = values['t_shift']
110
+ seed = values['seed']
111
+ scaling_method = values['scaling_method'] # ["Time-aware", "None"]
112
+ scaling_watershed = values['scaling_watershed']
113
+ proportional_attn = values['proportional_attn']
114
+
115
+ with torch.autocast("cuda", dtype):
116
+ try:
117
+ # begin sampler
118
+ transport = create_transport(
119
+ path_type,
120
+ prediction,
121
+ loss_weight,
122
+ train_eps,
123
+ sample_eps,
124
+ )
125
+ sampler = Sampler(transport)
126
+ sample_fn = sampler.sample_ode(
127
+ sampling_method=solver,
128
+ num_steps=num_sampling_steps,
129
+ atol=atol,
130
+ rtol=rtol,
131
+ reverse=reverse,
132
+ time_shifting_factor=t_shift,
133
+ )
134
+ # end sampler
135
+
136
+ do_extrapolation = "Extrapolation" in resolution
137
+ split = resolution.split(" ")[1].replace("(", "")
138
+ w_split, h_split = split.split("x")
139
+ resolution = resolution.split(" ")[0]
140
+ w, h = resolution.split("x")
141
+ w, h = int(w), int(h)
142
+ latent_w, latent_h = w // 8, h // 8
143
+ if int(seed) != 0:
144
+ torch.random.manual_seed(int(seed))
145
+ z = torch.randn([1, 4, latent_h, latent_w], device="cuda").to(dtype)
146
+ z = z.repeat(2, 1, 1, 1)
147
+
148
+ cap_list = [cap1, cap2, cap3, cap4]
149
+ global_cap = " ".join(cap_list)
150
+ with torch.no_grad():
151
+ if neg_cap != "":
152
+ cap_feats, cap_mask = encode_prompt(
153
+ cap_list + [neg_cap] + [global_cap], text_encoder, tokenizer, 0.0
154
+ )
155
+ else:
156
+ cap_feats, cap_mask = encode_prompt(
157
+ cap_list + [""] + [global_cap], text_encoder, tokenizer, 0.0
158
+ )
159
+
160
+ cap_mask = cap_mask.to(cap_feats.device)
161
+
162
+ model_kwargs = dict(
163
+ cap_feats=cap_feats[:-1],
164
+ cap_mask=cap_mask[:-1],
165
+ global_cap_feats=cap_feats[-1:],
166
+ global_cap_mask=cap_mask[-1:],
167
+ cfg_scale=cfg_scale,
168
+ h_split_num=int(h_split),
169
+ w_split_num=int(w_split),
170
+ )
171
+ if proportional_attn:
172
+ model_kwargs["proportional_attn"] = True
173
+ model_kwargs["base_seqlen"] = (train_args.image_size // 16) ** 2
174
+ else:
175
+ model_kwargs["proportional_attn"] = False
176
+ model_kwargs["base_seqlen"] = None
177
+
178
+ if do_extrapolation and scaling_method == "Time-aware":
179
+ model_kwargs["scale_factor"] = math.sqrt(w * h / train_args.image_size**2)
180
+ model_kwargs["scale_watershed"] = scaling_watershed
181
+ else:
182
+ model_kwargs["scale_factor"] = 1.0
183
+ model_kwargs["scale_watershed"] = 1.0
184
+
185
+ samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
186
+ samples = samples[:1]
187
+
188
+ factor = 0.18215 if train_args.vae != "sdxl" else 0.13025
189
+ samples = vae.decode(samples / factor).sample
190
+ samples = (samples + 1.0) / 2.0
191
+ samples.clamp_(0.0, 1.0)
192
+
193
+ img = to_pil_image(samples[0].float())
194
+
195
+ except Exception:
196
+ print(traceback.format_exc())
197
+
198
+ result = img
199
+ response = None
200
+ try:
201
+ source_id = values['source_id']
202
+ del values['source_id']
203
+ source_channel = values['source_channel']
204
+ del values['source_channel']
205
+ job_id = values['job_id']
206
+ del values['job_id']
207
+ default_filename = os.path.basename(result)
208
+ files = {default_filename: open(result, "rb").read()}
209
+ payload = {"content": f"{json.dumps(values)} <@{source_id}>"}
210
+ response = requests.post(
211
+ f"https://discord.com/api/v9/channels/{source_channel}/messages",
212
+ data=payload,
213
+ headers={"authorization": f"Bot {discord_token}"},
214
+ files=files
215
+ )
216
+ response.raise_for_status()
217
+ except Exception as e:
218
+ print(f"An unexpected error occurred: {e}")
219
+ finally:
220
+ if os.path.exists(result):
221
+ os.remove(result)
222
+
223
+ if response and response.status_code == 200:
224
+ try:
225
+ payload = {"jobId": job_id, "result": response.json()['attachments'][0]['url']}
226
+ requests.post(f"{web_uri}/api/notify", data=json.dumps(payload), headers={'Content-Type': 'application/json', "authorization": f"{web_token}"})
227
+ except Exception as e:
228
+ print(f"An unexpected error occurred: {e}")
229
+ finally:
230
+ return {"result": response.json()['attachments'][0]['url']}
231
+ else:
232
+ return {"result": "ERROR"}
233
+
234
+ runpod.serverless.start({"handler": generate})