Spaces:
Runtime error
Runtime error
Create worker_runpod.py
Browse files- worker_runpod.py +151 -0
worker_runpod.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, json, requests, runpod

# Credentials and endpoints injected by the deployment environment.
discord_token = os.getenv('com_camenduru_discord_token')
web_uri = os.getenv('com_camenduru_web_uri')
web_token = os.getenv('com_camenduru_web_token')

import random, time
import torch
import numpy as np
from PIL import Image
import nodes
from nodes import NODE_CLASS_MAPPINGS
from nodes import load_custom_node
from comfy_extras import nodes_custom_sampler
from comfy_extras import nodes_flux
from comfy import model_management
import gradio as gr  # NOTE(review): imported but not referenced anywhere in this file — confirm before removing

# Register the LLaVA-OneVision custom node pack so its node classes
# become available through NODE_CLASS_MAPPINGS below.
load_custom_node("/content/ComfyUI/custom_nodes/ComfyUI-LLaVA-OneVision")

# Instantiate the ComfyUI node wrappers used by the generate() pipeline.
DualCLIPLoader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
UNETLoader = NODE_CLASS_MAPPINGS["UNETLoader"]()
VAELoader = NODE_CLASS_MAPPINGS["VAELoader"]()

LoraLoader = NODE_CLASS_MAPPINGS["LoraLoader"]()
FluxGuidance = nodes_flux.NODE_CLASS_MAPPINGS["FluxGuidance"]()
RandomNoise = nodes_custom_sampler.NODE_CLASS_MAPPINGS["RandomNoise"]()
BasicGuider = nodes_custom_sampler.NODE_CLASS_MAPPINGS["BasicGuider"]()
KSamplerSelect = nodes_custom_sampler.NODE_CLASS_MAPPINGS["KSamplerSelect"]()
BasicScheduler = nodes_custom_sampler.NODE_CLASS_MAPPINGS["BasicScheduler"]()
SamplerCustomAdvanced = nodes_custom_sampler.NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
VAEDecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
EmptyLatentImage = NODE_CLASS_MAPPINGS["EmptyLatentImage"]()
DownloadAndLoadLLaVAOneVisionModel = NODE_CLASS_MAPPINGS["DownloadAndLoadLLaVAOneVisionModel"]()
LLaVA_OneVision_Run = NODE_CLASS_MAPPINGS["LLaVA_OneVision_Run"]()
LoadImage = NODE_CLASS_MAPPINGS["LoadImage"]()

# Load all model weights once at import time so every serverless job reuses
# the same in-memory models instead of reloading per request.
with torch.inference_mode():
    llava_model = DownloadAndLoadLLaVAOneVisionModel.loadmodel("lmms-lab/llava-onevision-qwen2-0.5b-si", "cuda", "bf16", "sdpa")[0]
    clip = DualCLIPLoader.load_clip("t5xxl_fp16.safetensors", "clip_l.safetensors", "flux")[0]
    unet = UNETLoader.load_unet("flux1-dev.sft", "default")[0]
    vae = VAELoader.load_vae("ae.sft")[0]
|
42 |
+
|
43 |
+
def closestNumber(n, m):
    """Return the multiple of m closest to n.

    The two candidates are the truncated multiple and its neighbor one step
    further from zero; on an exact tie the neighbor wins (so e.g. 120 with
    m=16 yields 128, not 112).
    """
    lower = m * int(n / m)
    # Step toward the next multiple: upward when n and m share a sign,
    # downward otherwise (mirrors truncation toward zero).
    step = m if (n * m) > 0 else -m
    upper = lower + step
    return lower if abs(n - lower) < abs(n - upper) else upper
|
54 |
+
def download_file(url, save_dir='/content/ComfyUI/input', timeout=600):
    """Download *url* into *save_dir* and return the local file path.

    The file name is taken from the last path segment of the URL.

    Args:
        url: HTTP(S) URL of the file to fetch.
        save_dir: Directory to write into (created if missing).
        timeout: Per-request timeout in seconds. The original code passed no
            timeout, so a stalled server could hang the worker indefinitely.

    Returns:
        Path of the written file.

    Raises:
        requests.HTTPError: if the server responds with a 4xx/5xx status.
    """
    os.makedirs(save_dir, exist_ok=True)
    file_name = url.split('/')[-1]
    file_path = os.path.join(save_dir, file_name)
    # Stream the body in chunks instead of buffering the whole file in memory.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(file_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=1 << 20):
                file.write(chunk)
    return file_path
|
64 |
+
@torch.inference_mode()
def generate(input):
    """RunPod serverless handler.

    Pipeline: download the input image, caption it with LLaVA-OneVision,
    generate a new image with FLUX (plus LoRA) conditioned on that caption,
    upload the result to Discord, then notify the web backend.

    Args:
        input: RunPod job dict; input["input"] carries the job parameters.

    Returns:
        {"result": <discord attachment url>} on success,
        {"result": "ERROR"} otherwise.
    """
    values = input["input"]

    # --- captioning (tagger) parameters ---
    tag_image = values['input_image_check']   # URL of the source image
    tag_image = download_file(tag_image)      # local path after download
    final_width = values['final_width']
    tag_prompt = values['tag_prompt']
    additional_prompt = values['additional_prompt']
    tag_seed = values['tag_seed']
    tag_temp = values['tag_temp']
    tag_max_tokens = values['tag_max_tokens']

    # --- diffusion parameters ---
    seed = values['seed']
    steps = values['steps']
    sampler_name = values['sampler_name']
    scheduler = values['scheduler']
    guidance = values['guidance']
    lora_strength_model = values['lora_strength_model']
    lora_strength_clip = values['lora_strength_clip']
    lora_file = values['lora_file']

    # model_management.unload_all_models()
    # Derive the output height so the source aspect ratio is preserved.
    tag_image_width, tag_image_height = Image.open(tag_image).size
    tag_image_aspect_ratio = tag_image_width / tag_image_height
    final_height = final_width / tag_image_aspect_ratio
    tag_image = LoadImage.load_image(tag_image)[0]
    # Seed value 0 means "pick a random seed".
    if tag_seed == 0:
        random.seed(int(time.time()))
        tag_seed = random.randint(0, 18446744073709551615)
    print(tag_seed)
    positive_prompt = LLaVA_OneVision_Run.run(tag_image, llava_model, tag_prompt, tag_max_tokens, True, tag_temp, tag_seed)[0]
    positive_prompt = f"{additional_prompt} {positive_prompt}"

    if seed == 0:
        random.seed(int(time.time()))
        seed = random.randint(0, 18446744073709551615)
    print(seed)
    unet_lora, clip_lora = LoraLoader.load_lora(unet, clip, lora_file, lora_strength_model, lora_strength_clip)
    cond, pooled = clip_lora.encode_from_tokens(clip_lora.tokenize(positive_prompt), return_pooled=True)
    cond = [[cond, {"pooled_output": pooled}]]
    cond = FluxGuidance.append(cond, guidance)[0]
    noise = RandomNoise.get_noise(seed)[0]
    guider = BasicGuider.get_guider(unet_lora, cond)[0]
    sampler = KSamplerSelect.get_sampler(sampler_name)[0]
    sigmas = BasicScheduler.get_sigmas(unet_lora, scheduler, steps, 1.0)[0]
    # Latent dimensions are snapped to multiples of 16.
    latent_image = EmptyLatentImage.generate(closestNumber(final_width, 16), closestNumber(final_height, 16))[0]
    sample, sample_denoised = SamplerCustomAdvanced.sample(noise, guider, sampler, sigmas, latent_image)
    decoded = VAEDecode.decode(vae, sample)[0].detach()
    Image.fromarray(np.array(decoded * 255, dtype=np.uint8)[0]).save("/content/onevision_flux.png")

    result = "/content/onevision_flux.png"
    result_url = None
    try:
        # Routing metadata is stripped from `values` before it is echoed back
        # to Discord as the message content.
        source_id = values['source_id']
        del values['source_id']
        source_channel = values['source_channel']
        del values['source_channel']
        job_id = values['job_id']
        del values['job_id']
        default_filename = os.path.basename(result)
        # Close the image file after reading (the original leaked the handle).
        with open(result, "rb") as image_file:
            files = {default_filename: image_file.read()}
        payload = {"content": f"{json.dumps(values)} <@{source_id}>"}
        response = requests.post(
            f"https://discord.com/api/v9/channels/{source_channel}/messages",
            data=payload,
            headers={"authorization": f"Bot {discord_token}"},
            files=files
        )
        response.raise_for_status()
        # Parse the attachment URL exactly once. The original parsed it twice
        # and returned it from inside a `finally`, which silently swallowed
        # any exception raised by the second parse.
        result_url = response.json()['attachments'][0]['url']
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    finally:
        # Always remove the generated file, success or not.
        if os.path.exists(result):
            os.remove(result)

    if result_url is not None:
        # Best-effort notification to the web backend; a failure here still
        # returns the attachment URL, matching the original behavior.
        try:
            payload = {"jobId": job_id, "result": result_url}
            requests.post(f"{web_uri}/api/notify", data=json.dumps(payload), headers={'Content-Type': 'application/json', "authorization": f"{web_token}"})
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
        return {"result": result_url}
    return {"result": "ERROR"}
150 |
+
|
151 |
+
# Hand generate() to the RunPod serverless runtime as the job handler.
runpod.serverless.start({"handler": generate})