camenduru commited on
Commit
d0b021f
1 Parent(s): f52b3dd

Rename worker.py to worker_runpod.py

Browse files
Files changed (1) hide show
  1. worker.py → worker_runpod.py +48 -27
worker.py → worker_runpod.py RENAMED
@@ -1,37 +1,57 @@
1
- from diffusers import AutoPipelineForText2Image
 
 
2
  import torch
3
- import json, os, requests
4
- import runpod
 
 
5
 
6
  discord_token = os.getenv('com_camenduru_discord_token')
7
  web_uri = os.getenv('com_camenduru_web_uri')
8
  web_token = os.getenv('com_camenduru_web_token')
9
 
10
- pipe = AutoPipelineForText2Image.from_pretrained(
11
- "misri/cyberrealisticXL_v11VAE",
12
- torch_dtype=torch.float16,
13
- variant="fp16",
14
- requires_safety_checker=False).to("cuda:0")
15
-
16
- def closestNumber(n, m):
17
- q = int(n / m)
18
- n1 = m * q
19
- if (n * m) > 0:
20
- n2 = m * (q + 1)
21
- else:
22
- n2 = m * (q - 1)
23
- if abs(n - n1) < abs(n - n2):
24
- return n1
25
- return n2
26
 
 
27
  def generate(input):
28
  values = input["input"]
29
- width = closestNumber(values['width'], 8)
30
- height = closestNumber(values['height'], 8)
31
- images = pipe(values['prompt'], negative_prompt=values['negative_prompt'], num_inference_steps=25, guidance_scale=7.5, width=width, height=height)
32
- result = f"/content/{input['id']}.png"
33
- images.images[0].save(result)
34
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  response = None
36
  try:
37
  source_id = values['source_id']
@@ -40,7 +60,8 @@ def generate(input):
40
  del values['source_channel']
41
  job_id = values['job_id']
42
  del values['job_id']
43
- files = {f"image.png": open(result, "rb").read()}
 
44
  payload = {"content": f"{json.dumps(values)} <@{source_id}>"}
45
  response = requests.post(
46
  f"https://discord.com/api/v9/channels/{source_channel}/messages",
@@ -66,4 +87,4 @@ def generate(input):
66
  else:
67
  return {"result": "ERROR"}
68
 
69
- runpod.serverless.start({"handler": generate})
 
1
+ import os, json, requests, runpod
2
+
3
+ import random
4
  import torch
5
+ import numpy as np
6
+ from PIL import Image
7
+ from comfy.sd import load_checkpoint_guess_config
8
+ import nodes
9
 
10
  discord_token = os.getenv('com_camenduru_discord_token')
11
  web_uri = os.getenv('com_camenduru_web_uri')
12
  web_token = os.getenv('com_camenduru_web_token')
13
 
14
+ with torch.inference_mode():
15
+ model_patcher, clip, vae, clipvision = load_checkpoint_guess_config("/content/ComfyUI/models/checkpoints/model.safetensors", output_vae=True, output_clip=True, embedding_directory=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ @torch.inference_mode()
18
  def generate(input):
19
  values = input["input"]
20
+
21
+ positive_prompt = values['positive_prompt']
22
+ negative_prompt = values['negative_prompt']
23
+ width = values['width']
24
+ height = values['height']
25
+ seed = values['seed']
26
+ steps = values['steps']
27
+ cfg = values['cfg']
28
+ sampler_name = values['sampler_name']
29
+ scheduler = values['scheduler']
30
+
31
+ latent = {"samples":torch.zeros([1, 4, height // 8, width // 8])}
32
+ cond, pooled = clip.encode_from_tokens(clip.tokenize(positive_prompt), return_pooled=True)
33
+ cond = [[cond, {"pooled_output": pooled}]]
34
+ n_cond, n_pooled = clip.encode_from_tokens(clip.tokenize(negative_prompt), return_pooled=True)
35
+ n_cond = [[n_cond, {"pooled_output": n_pooled}]]
36
+ if seed == 0:
37
+ seed = random.randint(0, 18446744073709551615)
38
+ print(seed)
39
+ sample = nodes.common_ksampler(model=model_patcher,
40
+ seed=seed,
41
+ steps=steps,
42
+ cfg=cfg,
43
+ sampler_name=sampler_name,
44
+ scheduler=scheduler,
45
+ positive=cond,
46
+ negative=n_cond,
47
+ latent=latent,
48
+ denoise=1)
49
+ sample = sample[0]["samples"].to(torch.float16)
50
+ vae.first_stage_model.cuda()
51
+ decoded = vae.decode_tiled(sample).detach()
52
+ Image.fromarray(np.array(decoded*255, dtype=np.uint8)[0]).save("/content/output_image.png")
53
+
54
+ result = "/content/output_image.png"
55
  response = None
56
  try:
57
  source_id = values['source_id']
 
60
  del values['source_channel']
61
  job_id = values['job_id']
62
  del values['job_id']
63
+ default_filename = os.path.basename(result)
64
+ files = {default_filename: open(result, "rb").read()}
65
  payload = {"content": f"{json.dumps(values)} <@{source_id}>"}
66
  response = requests.post(
67
  f"https://discord.com/api/v9/channels/{source_channel}/messages",
 
87
  else:
88
  return {"result": "ERROR"}
89
 
90
+ runpod.serverless.start({"handler": generate})