vibs08 commited on
Commit
577df10
·
verified ·
1 Parent(s): 4dafe5e

Upload 4 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. README.md +6 -5
  3. app.py +182 -0
  4. requirements.txt +10 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.whl filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,10 +1,11 @@
1
  ---
2
- title: Text To 3D
3
- emoji: 🚀
4
- colorFrom: green
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 4.41.0
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
1
  ---
2
+ title: TripoSR
3
+ emoji: 🐳
4
+ colorFrom: gray
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.20.1
8
+ python_version: 3.10.13
9
  app_file: app.py
10
  pinned: false
11
  license: mit
app.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import boto3
4
+ import json
5
+ import shlex
6
+ import subprocess
7
+ import tempfile
8
+ import time
9
+ import base64
10
+ import gradio as gr
11
+ import numpy as np
12
+ import rembg
13
+ import spaces
14
+ import torch
15
+ from PIL import Image
16
+ from functools import partial
17
+ import io
18
+
19
+ # s3 = boto3.client(
20
+ # 's3',
21
+ # aws_access_key_id="AKIAZW3QSPMIH4RF42UA",
22
+ # aws_secret_access_key="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",
23
+ # region_name='us-east-1'
24
+ # )
25
+
26
+ subprocess.run(shlex.split('pip install wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl'))
27
+
28
+ from tsr.system import TSR
29
+ from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
30
+
31
+
32
+ HEADER = """FRAME AI"""
33
+
34
+ if torch.cuda.is_available():
35
+ device = "cuda:0"
36
+ else:
37
+ device = "cpu"
38
+
39
+ model = TSR.from_pretrained(
40
+ "stabilityai/TripoSR",
41
+ config_name="config.yaml",
42
+ weight_name="model.ckpt",
43
+ )
44
+ model.renderer.set_chunk_size(131072)
45
+ model.to(device)
46
+
47
+ rembg_session = rembg.new_session()
48
+
49
+ def generate_image_from_text(pos_prompt):
50
+ # bedrock_runtime = boto3.client(region_name = 'us-east-1', service_name='bedrock-runtime')
51
+ bedrock_runtime = boto3.client(service_name='bedrock-runtime', aws_access_key_id = "AKIAZW3QSPMIH4RF42UA", aws_secret_access_key = "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", region_name='us-east-1')
52
+ parameters = {'text_prompts': [{'text':pos_prompt, 'weight':1},
53
+ {'text': """Blurry, unnatural, ugly, pixelated obscure, dull, artifacts, duplicate, bad quality, low resolution, cropped, out of frame, out of focus""", 'weight': -1}],
54
+ 'cfg_scale': 7, 'seed': 0, 'samples': 1}
55
+ request_body = json.dumps(parameters)
56
+ response = bedrock_runtime.invoke_model(body=request_body,modelId = 'stability.stable-diffusion-xl-v1')
57
+ response_body = json.loads(response.get('body').read())
58
+ base64_image_data = base64.b64decode(response_body['artifacts'][0]['base64'])
59
+
60
+ return Image.open(io.BytesIO(base64_image_data))
61
+
62
+ def check_input_image(input_image):
63
+ if input_image is None:
64
+ raise gr.Error("No image uploaded!")
65
+
66
+ def preprocess(input_image, do_remove_background, foreground_ratio):
67
+ def fill_background(image):
68
+ image = np.array(image).astype(np.float32) / 255.0
69
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
70
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
71
+ return image
72
+
73
+ if do_remove_background:
74
+ image = input_image.convert("RGB")
75
+ image = remove_background(image, rembg_session)
76
+ image = resize_foreground(image, foreground_ratio)
77
+ image = fill_background(image)
78
+ else:
79
+ image = input_image
80
+ if image.mode == "RGBA":
81
+ image = fill_background(image)
82
+ return image
83
+
84
+ @spaces.GPU
85
+ def generate(image, mc_resolution, formats=["obj", "glb"]):
86
+ scene_codes = model(image, device=device)
87
+ mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
88
+ mesh = to_gradio_3d_orientation(mesh)
89
+
90
+ mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f".glb", delete=False)
91
+ mesh.export(mesh_path_glb.name)
92
+
93
+ mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False)
94
+ mesh.apply_scale([-1, 1, 1]) # Otherwise the visualized .obj will be flipped
95
+ mesh.export(mesh_path_obj.name)
96
+
97
+ return mesh_path_obj.name, mesh_path_glb.name
98
+
99
+ def run_example(text_prompt, do_remove_background, foreground_ratio, mc_resolution):
100
+ # Step 1: Generate the image from text prompt
101
+ image_pil = generate_image_from_text(text_prompt)
102
+
103
+ # Step 2: Preprocess the image
104
+ preprocessed = preprocess(image_pil, do_remove_background, foreground_ratio)
105
+
106
+ # Step 3: Generate the 3D model
107
+ mesh_name_obj, mesh_name_glb = generate(preprocessed, mc_resolution, ["obj", "glb"])
108
+
109
+ return preprocessed, mesh_name_obj, mesh_name_glb
110
+
111
+ with gr.Blocks() as demo:
112
+ gr.Markdown(HEADER)
113
+ with gr.Row(variant="panel"):
114
+ with gr.Column():
115
+ with gr.Row():
116
+ text_prompt = gr.Textbox(
117
+ label="Text Prompt",
118
+ placeholder="Enter a text prompt for image generation"
119
+ )
120
+ input_image = gr.Image(
121
+ label="Generated Image",
122
+ image_mode="RGBA",
123
+ sources="upload",
124
+ type="pil",
125
+ elem_id="content_image",
126
+ visible=False # Hidden since we generate the image from text
127
+ )
128
+ processed_image = gr.Image(label="Processed Image", interactive=False)
129
+ with gr.Row():
130
+ with gr.Group():
131
+ do_remove_background = gr.Checkbox(
132
+ label="Remove Background", value=True
133
+ )
134
+ foreground_ratio = gr.Slider(
135
+ label="Foreground Ratio",
136
+ minimum=0.5,
137
+ maximum=1.0,
138
+ value=0.85,
139
+ step=0.05,
140
+ )
141
+ mc_resolution = gr.Slider(
142
+ label="Marching Cubes Resolution",
143
+ minimum=32,
144
+ maximum=320,
145
+ value=256,
146
+ step=32
147
+ )
148
+ with gr.Row():
149
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
150
+ with gr.Column():
151
+ with gr.Tab("OBJ"):
152
+ output_model_obj = gr.Model3D(
153
+ label="Output Model (OBJ Format)",
154
+ interactive=False,
155
+ )
156
+ gr.Markdown("Note: Downloaded object will be flipped in case of .obj export. Export .glb instead or manually flip it before usage.")
157
+ with gr.Tab("GLB"):
158
+ output_model_glb = gr.Model3D(
159
+ label="Output Model (GLB Format)",
160
+ interactive=False,
161
+ )
162
+ gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
163
+ with gr.Row(variant="panel"):
164
+ gr.Examples(
165
+ examples=[
166
+ os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
167
+ ],
168
+ inputs=[text_prompt],
169
+ outputs=[processed_image, output_model_obj, output_model_glb],
170
+ cache_examples=True,
171
+ fn=partial(run_example, do_remove_background=True, foreground_ratio=0.85, mc_resolution=256),
172
+ label="Examples",
173
+ examples_per_page=20
174
+ )
175
+ submit.click(fn=check_input_image, inputs=[text_prompt]).success(
176
+ fn=run_example,
177
+ inputs=[text_prompt, do_remove_background, foreground_ratio, mc_resolution],
178
+ outputs=[processed_image, output_model_obj, output_model_glb],
179
+ )
180
+
181
+ demo.queue(max_size=10)
182
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ omegaconf==2.3.0
2
+ Pillow==10.1.0
3
+ einops==0.7.0
4
+ torch==2.0.1
5
+ transformers==4.35.0
6
+ trimesh==4.0.5
7
+ rembg
8
+ huggingface-hub
9
+ gradio
10
+ boto3