Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -11,6 +11,7 @@ from torch import nn
|
|
11 |
from transformers import (
|
12 |
AutoTokenizer, Qwen2ForCausalLM, Qwen2Model, PreTrainedModel)
|
13 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
|
14 |
|
15 |
|
16 |
class FourierPointEncoder(nn.Module):
|
@@ -18,13 +19,13 @@ class FourierPointEncoder(nn.Module):
|
|
18 |
super().__init__()
|
19 |
frequencies = 2.0 ** torch.arange(8, dtype=torch.float32)
|
20 |
self.register_buffer('frequencies', frequencies, persistent=False)
|
21 |
-
self.projection = nn.Linear(
|
22 |
|
23 |
def forward(self, points):
|
24 |
-
x = points
|
25 |
x = (x.unsqueeze(-1) * self.frequencies).view(*x.shape[:-1], -1)
|
26 |
-
x = torch.cat((points
|
27 |
-
x = self.projection(
|
28 |
return x
|
29 |
|
30 |
|
@@ -113,15 +114,11 @@ class CADRecode(Qwen2ForCausalLM):
|
|
113 |
return model_inputs
|
114 |
|
115 |
|
116 |
-
def mesh_to_point_cloud(mesh, n_points=256):
|
117 |
-
vertices,
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
), axis=1)
|
122 |
-
ids = np.lexsort((point_cloud[:, 0], point_cloud[:, 1], point_cloud[:, 2]))
|
123 |
-
point_cloud = point_cloud[ids]
|
124 |
-
return point_cloud
|
125 |
|
126 |
|
127 |
def py_string_to_mesh_file(py_string, mesh_path, queue):
|
@@ -219,9 +216,7 @@ def run():
|
|
219 |
with gr.Row(equal_height=True):
|
220 |
in_model = gr.Model3D(label='1. Input Mesh', interactive=True)
|
221 |
point_model = gr.Model3D(label='2. Sampled Point Cloud', display_mode='point_cloud', interactive=False)
|
222 |
-
out_model = gr.Model3D(
|
223 |
-
label='4. Result CAD Model', interactive=False
|
224 |
-
)
|
225 |
|
226 |
with gr.Row():
|
227 |
with gr.Column():
|
@@ -230,27 +225,30 @@ def run():
|
|
230 |
with gr.Row():
|
231 |
gr.Examples(
|
232 |
examples=[
|
233 |
-
['./data/49215_5368e45e_0000.stl',
|
234 |
-
['./data/
|
235 |
-
['./data/User Library-engrenage.stl', 18],
|
236 |
['./data/00010900.stl', 42],
|
237 |
-
['./data/
|
238 |
-
['./data/
|
239 |
-
['./data/
|
240 |
-
['./data/41473_c2137170_0023.stl', 42],
|
241 |
-
['./data/22447_4062c6cb_0011.stl', 42],
|
242 |
['./data/27694_7801dc67_0017.stl', 45],
|
243 |
-
['./data/49562_6df35938_0005.stl',
|
244 |
-
['./data/131709_8b86dfb6_0000.stl', 45],
|
245 |
['./data/00081523.stl', 44],
|
246 |
-
['./data/
|
247 |
-
['./data/
|
248 |
-
['./data/
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
example_labels=[
|
250 |
-
'fusion360_table1', '
|
251 |
-
'
|
252 |
-
'
|
253 |
-
'
|
|
|
254 |
inputs=[in_model, seed_slider],
|
255 |
cache_examples=False,
|
256 |
examples_per_page=20)
|
@@ -286,7 +284,7 @@ tokenizer = AutoTokenizer.from_pretrained(
|
|
286 |
pad_token='<|im_end|>',
|
287 |
padding_side='left')
|
288 |
cad_recode = CADRecode.from_pretrained(
|
289 |
-
'filapro/cad-recode',
|
290 |
torch_dtype='auto',
|
291 |
attn_implementation='flash_attention_2').eval()
|
292 |
|
|
|
11 |
from transformers import (
|
12 |
AutoTokenizer, Qwen2ForCausalLM, Qwen2Model, PreTrainedModel)
|
13 |
from transformers.modeling_outputs import CausalLMOutputWithPast
|
14 |
+
from pytorch3d.ops import sample_farthest_points
|
15 |
|
16 |
|
17 |
class FourierPointEncoder(nn.Module):
|
|
|
19 |
super().__init__()
|
20 |
frequencies = 2.0 ** torch.arange(8, dtype=torch.float32)
|
21 |
self.register_buffer('frequencies', frequencies, persistent=False)
|
22 |
+
self.projection = nn.Linear(51, hidden_size)
|
23 |
|
24 |
def forward(self, points):
|
25 |
+
x = points
|
26 |
x = (x.unsqueeze(-1) * self.frequencies).view(*x.shape[:-1], -1)
|
27 |
+
x = torch.cat((points, x.sin(), x.cos()), dim=-1)
|
28 |
+
x = self.projection(x)
|
29 |
return x
|
30 |
|
31 |
|
|
|
114 |
return model_inputs
|
115 |
|
116 |
|
117 |
+
def mesh_to_point_cloud(mesh, n_points=256, n_pre_points=8192):
|
118 |
+
vertices, _ = trimesh.sample.sample_surface(mesh, n_pre_points)
|
119 |
+
_, ids = sample_farthest_points(torch.tensor(vertices).unsqueeze(0), K=n_points)
|
120 |
+
ids = ids[0].numpy()
|
121 |
+
return np.asarray(vertices[ids])
|
|
|
|
|
|
|
|
|
122 |
|
123 |
|
124 |
def py_string_to_mesh_file(py_string, mesh_path, queue):
|
|
|
216 |
with gr.Row(equal_height=True):
|
217 |
in_model = gr.Model3D(label='1. Input Mesh', interactive=True)
|
218 |
point_model = gr.Model3D(label='2. Sampled Point Cloud', display_mode='point_cloud', interactive=False)
|
219 |
+
out_model = gr.Model3D(label='4. Result CAD Model', interactive=False)
|
|
|
|
|
220 |
|
221 |
with gr.Row():
|
222 |
with gr.Column():
|
|
|
225 |
with gr.Row():
|
226 |
gr.Examples(
|
227 |
examples=[
|
228 |
+
['./data/49215_5368e45e_0000.stl', 46],
|
229 |
+
['./data/User Library-engrenage.stl', 46],
|
|
|
230 |
['./data/00010900.stl', 42],
|
231 |
+
['./data/00375556.stl', 44],
|
232 |
+
['./data/41473_c2137170_0023.stl', 47],
|
233 |
+
['./data/22447_4062c6cb_0011.stl', 43],
|
|
|
|
|
234 |
['./data/27694_7801dc67_0017.stl', 45],
|
235 |
+
['./data/49562_6df35938_0005.stl', 48],
|
|
|
236 |
['./data/00081523.stl', 44],
|
237 |
+
['./data/User Library-Caster, Superior Brand, 600lb 2 EDIT.stl', 57],
|
238 |
+
['./data/00810270.stl', 44],
|
239 |
+
['./data/00295262.stl', 42],
|
240 |
+
['./data/132156_cd3f1428_0000.stl', 59],
|
241 |
+
['./data/User Library-All Weather Padlock_Padlock Key.stl', 66],
|
242 |
+
['./data/00047559.stl', 42],
|
243 |
+
['./data/00052220.stl', 42],
|
244 |
+
['./data/141665_0564e852_0003.stl', 42],
|
245 |
+
['./data/User Library-_50 Cal Round.stl', 42]],
|
246 |
example_labels=[
|
247 |
+
'fusion360_table1', 'cc3d_gear', 'deepcad_barrels', 'deepcad_house',
|
248 |
+
'fusion360_omega', 'fusion360_hat', 'fusion360_table3', 'fusion360_bolt',
|
249 |
+
'deepcad_alpha', 'cc3d_caster', 'deepcad_spinner', 'deepcad_5bricks',
|
250 |
+
'fusion360_star', 'cc3d_key', 'deepcad_3cubes', 'deepcad_arrow',
|
251 |
+
'fusion360_frame', 'cc3d_bullet'],
|
252 |
inputs=[in_model, seed_slider],
|
253 |
cache_examples=False,
|
254 |
examples_per_page=20)
|
|
|
284 |
pad_token='<|im_end|>',
|
285 |
padding_side='left')
|
286 |
cad_recode = CADRecode.from_pretrained(
|
287 |
+
'filapro/cad-recode-v1.5',
|
288 |
torch_dtype='auto',
|
289 |
attn_implementation='flash_attention_2').eval()
|
290 |
|