filapro commited on
Commit
ff658ec
·
verified ·
1 Parent(s): c7ad19a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -34
app.py CHANGED
@@ -11,6 +11,7 @@ from torch import nn
11
  from transformers import (
12
  AutoTokenizer, Qwen2ForCausalLM, Qwen2Model, PreTrainedModel)
13
  from transformers.modeling_outputs import CausalLMOutputWithPast
 
14
 
15
 
16
  class FourierPointEncoder(nn.Module):
@@ -18,13 +19,13 @@ class FourierPointEncoder(nn.Module):
18
  super().__init__()
19
  frequencies = 2.0 ** torch.arange(8, dtype=torch.float32)
20
  self.register_buffer('frequencies', frequencies, persistent=False)
21
- self.projection = nn.Linear(54, hidden_size)
22
 
23
  def forward(self, points):
24
- x = points[..., :3]
25
  x = (x.unsqueeze(-1) * self.frequencies).view(*x.shape[:-1], -1)
26
- x = torch.cat((points[..., :3], x.sin(), x.cos()), dim=-1)
27
- x = self.projection(torch.cat((x, points[..., 3:]), dim=-1))
28
  return x
29
 
30
 
@@ -113,15 +114,11 @@ class CADRecode(Qwen2ForCausalLM):
113
  return model_inputs
114
 
115
 
116
- def mesh_to_point_cloud(mesh, n_points=256):
117
- vertices, faces = trimesh.sample.sample_surface(mesh, n_points)
118
- point_cloud = np.concatenate((
119
- np.asarray(vertices),
120
- mesh.face_normals[faces]
121
- ), axis=1)
122
- ids = np.lexsort((point_cloud[:, 0], point_cloud[:, 1], point_cloud[:, 2]))
123
- point_cloud = point_cloud[ids]
124
- return point_cloud
125
 
126
 
127
  def py_string_to_mesh_file(py_string, mesh_path, queue):
@@ -219,9 +216,7 @@ def run():
219
  with gr.Row(equal_height=True):
220
  in_model = gr.Model3D(label='1. Input Mesh', interactive=True)
221
  point_model = gr.Model3D(label='2. Sampled Point Cloud', display_mode='point_cloud', interactive=False)
222
- out_model = gr.Model3D(
223
- label='4. Result CAD Model', interactive=False
224
- )
225
 
226
  with gr.Row():
227
  with gr.Column():
@@ -230,27 +225,30 @@ def run():
230
  with gr.Row():
231
  gr.Examples(
232
  examples=[
233
- ['./data/49215_5368e45e_0000.stl', 42],
234
- ['./data/00882236.stl', 6],
235
- ['./data/User Library-engrenage.stl', 18],
236
  ['./data/00010900.stl', 42],
237
- ['./data/21492_8bd34fc1_0008.stl', 42],
238
- ['./data/00375556.stl', 96],
239
- ['./data/49121_adb01620_0000.stl', 42],
240
- ['./data/41473_c2137170_0023.stl', 42],
241
- ['./data/22447_4062c6cb_0011.stl', 42],
242
  ['./data/27694_7801dc67_0017.stl', 45],
243
- ['./data/49562_6df35938_0005.stl', 44],
244
- ['./data/131709_8b86dfb6_0000.stl', 45],
245
  ['./data/00081523.stl', 44],
246
- ['./data/00614972.stl', 44],
247
- ['./data/User Library-Cople Ventosa V2.stl', 50],
248
- ['./data/User Library-Caster, Superior Brand, 600lb 2 EDIT.stl', 52]],
 
 
 
 
 
 
249
  example_labels=[
250
- 'fusion360_table1', 'deepcad_star', 'cc3d_gear', 'deepcad_barrels',
251
- 'fusion360_gear', 'deepcad_house', 'fusion360_table2', 'fusion360_omega',
252
- 'fusion360_hat', 'fusion360_table3', 'fusion360_bolt', 'fusion360_clamp',
253
- 'deepcad_alpha', 'deepcad_pulley', 'cc3d_flange', 'cc3d_caster'],
 
254
  inputs=[in_model, seed_slider],
255
  cache_examples=False,
256
  examples_per_page=20)
@@ -286,7 +284,7 @@ tokenizer = AutoTokenizer.from_pretrained(
286
  pad_token='<|im_end|>',
287
  padding_side='left')
288
  cad_recode = CADRecode.from_pretrained(
289
- 'filapro/cad-recode',
290
  torch_dtype='auto',
291
  attn_implementation='flash_attention_2').eval()
292
 
 
11
  from transformers import (
12
  AutoTokenizer, Qwen2ForCausalLM, Qwen2Model, PreTrainedModel)
13
  from transformers.modeling_outputs import CausalLMOutputWithPast
14
+ from pytorch3d.ops import sample_farthest_points
15
 
16
 
17
  class FourierPointEncoder(nn.Module):
 
19
  super().__init__()
20
  frequencies = 2.0 ** torch.arange(8, dtype=torch.float32)
21
  self.register_buffer('frequencies', frequencies, persistent=False)
22
+ self.projection = nn.Linear(51, hidden_size)
23
 
24
  def forward(self, points):
25
+ x = points
26
  x = (x.unsqueeze(-1) * self.frequencies).view(*x.shape[:-1], -1)
27
+ x = torch.cat((points, x.sin(), x.cos()), dim=-1)
28
+ x = self.projection(x)
29
  return x
30
 
31
 
 
114
  return model_inputs
115
 
116
 
117
+ def mesh_to_point_cloud(mesh, n_points=256, n_pre_points=8192):
118
+ vertices, _ = trimesh.sample.sample_surface(mesh, n_pre_points)
119
+ _, ids = sample_farthest_points(torch.tensor(vertices).unsqueeze(0), K=n_points)
120
+ ids = ids[0].numpy()
121
+ return np.asarray(vertices[ids])
 
 
 
 
122
 
123
 
124
  def py_string_to_mesh_file(py_string, mesh_path, queue):
 
216
  with gr.Row(equal_height=True):
217
  in_model = gr.Model3D(label='1. Input Mesh', interactive=True)
218
  point_model = gr.Model3D(label='2. Sampled Point Cloud', display_mode='point_cloud', interactive=False)
219
+ out_model = gr.Model3D(label='4. Result CAD Model', interactive=False)
 
 
220
 
221
  with gr.Row():
222
  with gr.Column():
 
225
  with gr.Row():
226
  gr.Examples(
227
  examples=[
228
+ ['./data/49215_5368e45e_0000.stl', 46],
229
+ ['./data/User Library-engrenage.stl', 46],
 
230
  ['./data/00010900.stl', 42],
231
+ ['./data/00375556.stl', 44],
232
+ ['./data/41473_c2137170_0023.stl', 47],
233
+ ['./data/22447_4062c6cb_0011.stl', 43],
 
 
234
  ['./data/27694_7801dc67_0017.stl', 45],
235
+ ['./data/49562_6df35938_0005.stl', 48],
 
236
  ['./data/00081523.stl', 44],
237
+ ['./data/User Library-Caster, Superior Brand, 600lb 2 EDIT.stl', 57],
238
+ ['./data/00810270.stl', 44],
239
+ ['./data/00295262.stl', 42],
240
+ ['./data/132156_cd3f1428_0000.stl', 59],
241
+ ['./data/User Library-All Weather Padlock_Padlock Key.stl', 66],
242
+ ['./data/00047559.stl', 42],
243
+ ['./data/00052220.stl', 42],
244
+ ['./data/141665_0564e852_0003.stl', 42],
245
+ ['./data/User Library-_50 Cal Round.stl', 42]],
246
  example_labels=[
247
+ 'fusion360_table1', 'cc3d_gear', 'deepcad_barrels', 'deepcad_house',
248
+ 'fusion360_omega', 'fusion360_hat', 'fusion360_table3', 'fusion360_bolt',
249
+ 'deepcad_alpha', 'cc3d_caster', 'deepcad_spinner', 'deepcad_5bricks',
250
+ 'fusion360_star', 'cc3d_key', 'deepcad_3cubes', 'deepcad_arrow',
251
+ 'fusion360_frame', 'cc3d_bullet'],
252
  inputs=[in_model, seed_slider],
253
  cache_examples=False,
254
  examples_per_page=20)
 
284
  pad_token='<|im_end|>',
285
  padding_side='left')
286
  cad_recode = CADRecode.from_pretrained(
287
+ 'filapro/cad-recode-v1.5',
288
  torch_dtype='auto',
289
  attn_implementation='flash_attention_2').eval()
290