Refactor code for improved performance and readability

- .gitignore +2 -0
- TripoSR +1 -0
- app.py +13 -0
- imgGen.py +28 -0
- requirements.txt +12 -0
- worker.py +86 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
+.env
+output.png
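
worker.py below loads the Supabase credentials from this ignored .env file via python-dotenv. A minimal sketch with placeholder values (the variable names come from worker.py; the values are hypothetical):

SUPABASE_URL=https://your-project-id.supabase.co
SUPABASE_KEY=your-supabase-api-key
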
TripoSR
ADDED
@@ -0,0 +1 @@
+Subproject commit 8e51fec8095c9eae20e6ea7c9aef6368c5631a21
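
The line above records TripoSR as a git submodule pinned to that commit; after cloning the Space, it presumably has to be fetched before worker.py can import from it:

git submodule update --init --recursive
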
app.py
ADDED
@@ -0,0 +1,13 @@
+import gradio as gr
+from worker import worker  # import the worker function
+
+def greet(name):
+    return "Hello " + name + "!!"
+
+def kickoff_worker():
+    worker()  # call the worker function
+
+iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+# launch() blocks the main thread by default; prevent_thread_lock=True lets the
+# script continue so the worker loop below actually runs.
+iface.launch(prevent_thread_lock=True)
+print("Launching worker...")
+kickoff_worker()  # kick off the worker after launching the interface
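
Since worker() is an infinite polling loop, an alternative arrangement (not what this commit does) is to run it on a daemon thread and let launch() keep the main thread; a minimal sketch:

import threading
import gradio as gr
from worker import worker

def greet(name):
    return "Hello " + name + "!!"

# Poll for tasks in the background while Gradio blocks the main thread.
threading.Thread(target=worker, daemon=True).start()
gr.Interface(fn=greet, inputs="text", outputs="text").launch()
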
imgGen.py
ADDED
@@ -0,0 +1,28 @@
+import torch
+from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+
+
+base = "stabilityai/stable-diffusion-xl-base-1.0"
+repo = "ByteDance/SDXL-Lightning"
+ckpt = "sdxl_lightning_4step_unet.safetensors"  # Use the correct ckpt for your step setting!
+
+
+# Load model.
+unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
+unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
+pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
+
+# Ensure sampler uses "trailing" timesteps.
+pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+
+def generateTransparentImage(text):
+    # Ensure using the same inference steps as the loaded model and CFG set to 0.
+    image = pipe(text + ', full body, transparent background', num_inference_steps=4, guidance_scale=0).images[0]
+    return image
+
+if __name__ == "__main__":
+    text = "a cat"
+    img = generateTransparentImage(text)
+    img.save("output.png")
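
Despite the "transparent background" prompt, the pipeline returns an opaque RGB image; if a real alpha channel is needed, rembg (already listed in requirements.txt below) can cut out the subject. A minimal sketch, not part of this commit:

from rembg import remove
from imgGen import generateTransparentImage

img = generateTransparentImage("a cat")  # opaque RGB render on a plain background
cutout = remove(img)                     # RGBA image with the background removed
cutout.save("output_rgba.png")
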
requirements.txt
ADDED
@@ -0,0 +1,12 @@
+omegaconf==2.3.0
+Pillow==10.1.0
+einops==0.7.0
+git+https://github.com/tatsy/torchmcubes.git
+diffusers[torch]
+transformers==4.35.0
+trimesh==4.0.5
+rembg
+huggingface-hub
+imageio[ffmpeg]
+setuptools
+# CUDA 11.8 wheels for the PyTorch stack
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch
+torchvision
+torchaudio
+# imported by worker.py
+supabase
+python-dotenv

worker.py
ADDED
@@ -0,0 +1,86 @@
+from supabase import create_client, Client
+from imgGen import generateTransparentImage
+import sys
+sys.path.append('./TripoSR')
+
+import TripoSR.obj_gen as obj_gen
+import os
+from dotenv import load_dotenv
+import time
+load_dotenv()
+
+
+url: str = os.environ.get("SUPABASE_URL")
+key: str = os.environ.get("SUPABASE_KEY")
+supabase: Client = create_client(url, key)
+
+def check_queue():
+    # Return the first pending task, or None if the queue is empty.
+    try:
+        tasks = supabase.table("Tasks").select("*").eq("status", "pending").execute()
+        if len(tasks.data) > 0:
+            return tasks.data[0]
+        else:
+            return None
+    except Exception as e:
+        print(f"Error checking queue: {e}")
+        return None
+
+
+def generate_image(text):
+    try:
+        img = generateTransparentImage(text)
+        return img
+    except Exception as e:
+        print(f"Error generating image: {e}")
+        return None
+
+
+def create_obj_file(img, task_id):
+    try:
+        obj_gen.generate_obj_from_image(img, 'task_' + str(task_id) + '.obj')
+    except Exception as e:
+        print(f"Error creating obj file: {e}")
+        supabase.table("Tasks").update({"status": "error"}).eq("id", task_id).execute()
+
+def send_back_to_supabase(task_id):
+    # check that a file was created
+    if os.path.exists('task_' + str(task_id) + '.obj'):
+        try:
+            with open('task_' + str(task_id) + '.obj', 'rb') as file:
+                data = file.read()
+            supabase.storage.from_('Results').upload('task_' + str(task_id) + '.obj', data)
+            public_url = supabase.storage.from_('Results').get_public_url('task_' + str(task_id) + '.obj')
+            supabase.table("Tasks").update({"status": "complete", "result": public_url}).eq("id", task_id).execute()
+            os.remove('task_' + str(task_id) + '.obj')
+        except Exception as e:
+            print(f"Error sending file back to Supabase: {e}")
+            supabase.table("Tasks").update({"status": "error"}).eq("id", task_id).execute()
+
+    else:
+        print(f"Error: No file was created for task {task_id}")
+
+def worker():
+    while True:
+        task = check_queue()
+        if task:
+            supabase.table("Tasks").update({"status": "processing"}).eq("id", task['id']).execute()
+            print(f"Processing task {task['id']}")
+            img = generate_image(task["text"])
+            if img is not None:
+                print(f"Image generated for task {task['id']}")
+                create_obj_file(img, task["id"])
+                send_back_to_supabase(task["id"])
+                print(f"Task {task['id']} completed")
+            else:
+                print(f"Error generating image for task {task['id']}")
+                supabase.table("Tasks").update({"status": "error"}).eq("id", task['id']).execute()
+
+        else:
+            print("No pending tasks in the queue")
+
+        time.sleep(2)  # Add a 2 second delay between checks
+
+if __name__ == "__main__":
+    worker()
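
worker.py drives everything off a Supabase table named Tasks and a storage bucket named Results: a row's status moves from pending to processing and ends at complete (with result set to the public URL of the uploaded .obj) or error. A sketch of the row shape those calls imply; the column set is inferred from the code above, not taken from this repo:

# Hypothetical shape of one row in the Tasks table, as worker.py reads and writes it.
example_task = {
    "id": 1,              # primary key used in the .eq("id", ...) filters
    "text": "a cat",      # prompt handed to imgGen.generateTransparentImage
    "status": "pending",  # pending -> processing -> complete | error
    "result": None,       # filled with the public URL of task_<id>.obj in Results
}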