CazC commited on
Commit
5a2deaa
·
1 Parent(s): 1710a77

Refactor code for improved performance and readability

Browse files
Files changed (6) hide show
  1. .gitignore +2 -0
  2. TripoSR +1 -0
  3. app.py +13 -0
  4. imgGen.py +28 -0
  5. requirements.txt +12 -0
  6. worker.py +86 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .env
2
+ output.png
TripoSR ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 8e51fec8095c9eae20e6ea7c9aef6368c5631a21
app.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from worker import worker # import the worker function
3
+
4
+ def greet(name):
5
+ return "Hello " + name + "!!"
6
+
7
+ def kickoff_worker():
8
+ worker() # call the worker function
9
+
10
+ iface = gr.Interface(fn=greet, inputs="text", outputs="text")
11
+ iface.launch()
12
+ print("Launching worker...")
13
+ kickoff_worker() # kickoff the worker after launching the interface
imgGen.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
3
+ from huggingface_hub import hf_hub_download
4
+ from safetensors.torch import load_file
5
+
6
+
7
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
8
+ repo = "ByteDance/SDXL-Lightning"
9
+ ckpt = "sdxl_lightning_4step_unet.safetensors" # Use the correct ckpt for your step setting!
10
+
11
+
12
+ # Load model.
13
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
14
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
15
+ pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
16
+
17
+ # Ensure sampler uses "trailing" timesteps.
18
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
19
+
20
+ def generateTransparentImage(text):
21
+ # Ensure using the same inference steps as the loaded model and CFG set to 0.
22
+ image = pipe(text+', full body, transparent background', num_inference_steps=4, guidance_scale=0).images[0]
23
+ return image
24
+
25
+ if __name__ == "__main__":
26
+ text = "a cat"
27
+ img = generateTransparentImage(text)
28
+ img.save("output.png")
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ omegaconf==2.3.0
2
+ Pillow==10.1.0
3
+ einops==0.7.0
4
+ git+https://github.com/tatsy/torchmcubes.git
5
+ diffusers["torch"]
6
+ transformers==4.35.0
7
+ trimesh==4.0.5
8
+ rembg
9
+ huggingface-hub
10
+ imageio[ffmpeg]
11
+ setuptools --upgrade
12
+ torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
worker.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ from supabase import create_client, Client
3
+ from imgGen import generateTransparentImage
4
+ import sys
5
+ sys.path.append('./TripoSR')
6
+
7
+ import TripoSR.obj_gen as obj_gen
8
+ import os
9
+ from dotenv import load_dotenv
10
+ import time
11
+ load_dotenv()
12
+
13
+
14
+ url: str = os.environ.get("SUPABASE_URL")
15
+ key: str = os.environ.get("SUPABASE_KEY")
16
+ supabase: Client = create_client(url, key)
17
+
18
+ def check_queue():
19
+ try:
20
+ tasks = supabase.table("Tasks").select("*").eq("status", "pending").execute()
21
+ assert len(tasks.data) > 0
22
+ if len(tasks.data) > 0:
23
+ return tasks.data[0]
24
+ else:
25
+ return None
26
+ except Exception as e:
27
+ print(f"Error checking queue: {e}")
28
+ return None
29
+
30
+
31
+ def generate_image(text):
32
+ try:
33
+ img = generateTransparentImage(text)
34
+ return img
35
+ except Exception as e:
36
+ print(f"Error generating image: {e}")
37
+ return None
38
+
39
+
40
+ def create_obj_file(img, task_id):
41
+ try:
42
+ obj_gen.generate_obj_from_image(img, 'task_'+str(task_id)+'.obj')
43
+ except Exception as e:
44
+ print(f"Error creating obj file: {e}")
45
+ supabase.table("Tasks").update({"status": "error"}).eq("id", task_id).execute()
46
+
47
+ def send_back_to_supabase(task_id):
48
+ # check that a file was created
49
+ if os.path.exists('task_'+str(task_id)+'.obj'):
50
+ try:
51
+ with open('task_'+str(task_id)+'.obj', 'rb') as file:
52
+ data = file.read()
53
+ supabase.storage.from_('Results').upload('task_'+str(task_id)+'.obj', data)
54
+ public_url = supabase.storage.from_('Results').get_public_url('task_'+str(task_id)+'.obj')
55
+ supabase.table("Tasks").update({"status": "complete","result":public_url}).eq("id", task_id).execute()
56
+ os.remove('task_'+str(task_id)+'.obj')
57
+ except Exception as e:
58
+ print(f"Error sending file back to Supabase: {e}")
59
+ supabase.table("Tasks").update({"status": "error"}).eq("id", task_id).execute()
60
+
61
+ else:
62
+ print(f"Error: No file was created for task {task_id}")
63
+
64
+ def worker():
65
+ while True:
66
+ task = check_queue()
67
+ if task:
68
+ supabase.table("Tasks").update({"status": "processing"}).eq("id", task['id']).execute()
69
+ print(f"Processing task {task['id']}")
70
+ img = generate_image(task["text"])
71
+ if img:
72
+ print(f"Image generated for task {task['id']}")
73
+ create_obj_file(img,task["id"])
74
+ send_back_to_supabase(task["id"])
75
+ print(f"Task {task['id']} completed")
76
+ else:
77
+ print(f"Error generating image for task {task['id']}")
78
+ supabase.table("Tasks").update({"status": "error"}).eq("id", task['id']).execute()
79
+
80
+ else:
81
+ print("No pending tasks in the queue")
82
+
83
+ time.sleep(2) # Add a 2 second delay between checks
84
+
85
+ if __name__ == "__main__":
86
+ worker()