02alexander commited on
Commit
fd49e19
1 Parent(s): 34d9a52

restructure code

Browse files
Files changed (1) hide show
  1. app.py +84 -72
app.py CHANGED
@@ -154,61 +154,66 @@ def preprocess(input_image, do_remove_background):
154
  return input_image
155
 
156
 
157
- def pipeline_callback(output_queue: SimpleQueue, pipe: Any, step_index: int, timestep: float, callback_kwargs: dict[str, Any]) -> dict[str, Any]:
158
  latents = callback_kwargs["latents"]
159
  image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0] # type: ignore[attr-defined]
160
  image = pipe.image_processor.postprocess(image, output_type="np").squeeze() # type: ignore[attr-defined]
161
 
162
- output_queue.put(("log", "mvs/image", rr.Image(image)))
163
- output_queue.put(("log", "mvs/latents", rr.Tensor(latents.squeeze())))
164
 
165
  return callback_kwargs
166
 
167
- def generate_mvs(input_image, sample_steps, sample_seed):
168
 
169
  seed_everything(sample_seed)
170
 
171
-
172
- def thread_target(output_queue, input_image, sample_steps):
173
- z123_image = pipeline(
174
- input_image,
175
- num_inference_steps=sample_steps,
176
- callback_on_step_end=lambda *args, **kwargs: pipeline_callback(output_queue, *args, **kwargs),
177
- ).images[0]
178
- output_queue.put(("z123_image", z123_image))
179
-
180
-
181
- output_queue = SimpleQueue()
182
- z123_thread = threading.Thread(
183
- target=thread_target,
184
- args=
185
- [
186
- output_queue,
187
- input_image,
188
- sample_steps,
189
- ]
190
- )
191
- z123_thread.start()
192
-
193
- while True:
194
- msg = output_queue.get()
195
- yield msg
196
- if msg[0] == "z123_image":
197
- break
198
- z123_thread.join()
199
-
200
- def make3d(images: Image.Image):
201
- output_queue = SimpleQueue()
202
- handle = threading.Thread(target=_make3d, args=[output_queue, images])
203
- handle.start()
204
- while True:
205
- msg = output_queue.get()
206
- yield msg
207
- if msg[0] == "mesh":
208
- break
209
- handle.join()
210
-
211
- def _make3d(output_queue: SimpleQueue, images: Image.Image):
 
 
 
 
 
212
  global model
213
  if IS_FLEXICUBES:
214
  model.init_flexicubes_geometry(device, use_renderer=False)
@@ -238,9 +243,8 @@ def _make3d(output_queue: SimpleQueue, images: Image.Image):
238
 
239
  vertices, faces, vertex_colors = mesh_out
240
 
241
- output_queue.put(
242
  (
243
- "log",
244
  "mesh",
245
  rr.Mesh3D(
246
  vertex_positions=vertices,
@@ -249,7 +253,8 @@ def _make3d(output_queue: SimpleQueue, images: Image.Image):
249
  ),
250
  )
251
  )
252
- output_queue.put(("mesh", mesh_out))
 
253
 
254
  def generate_blueprint() -> rrb.Blueprint:
255
  return rrb.Blueprint(
@@ -258,49 +263,56 @@ def generate_blueprint() -> rrb.Blueprint:
258
  rrb.Grid(
259
  rrb.Spatial2DView(origin="z123image"),
260
  rrb.Spatial2DView(origin="preprocessed_image"),
261
- rrb.Spatial2DView(origin="mvs/image"),
262
- rrb.TensorView(origin="mvs/latents", ),
263
  ),
264
  column_shares=[1, 1],
265
  ),
 
266
  collapse_panels=True,
267
  )
268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  @spaces.GPU
270
  @rr.thread_local_stream("InstantMesh")
271
  def log_to_rr(input_image, do_remove_background, sample_steps, sample_seed):
272
 
 
 
273
  stream = rr.binary_stream()
274
 
275
  blueprint = generate_blueprint()
276
  rr.send_blueprint(blueprint)
277
  yield stream.read()
278
 
279
- preprocessed_image = preprocess(input_image, do_remove_background)
280
- rr.log("preprocessed_image", rr.Image(preprocessed_image))
281
-
282
- yield stream.read()
283
-
284
- for msg in generate_mvs(preprocessed_image, sample_steps, sample_seed):
285
- if msg[0] == "z123_image":
286
- z123_image = msg[1]
287
  break
288
- elif msg[0] == "log":
289
- entity_path = msg[1]
290
- entity = msg[2]
291
  rr.log(entity_path, entity)
292
  yield stream.read()
293
-
294
- rr.log("z123image", rr.Image(z123_image))
295
- yield stream.read()
296
-
297
- for msg in make3d(z123_image):
298
- if msg[0] == "log":
299
- rr.log(msg[1], msg[2])
300
- yield stream.read()
301
- if msg[0] == "mesh":
302
- mesh = msg[1]
303
-
304
  # return mesh
305
 
306
  _HEADER_ = '''
 
154
  return input_image
155
 
156
 
157
+ def pipeline_callback(log_queue: SimpleQueue, pipe: Any, step_index: int, timestep: float, callback_kwargs: dict[str, Any]) -> dict[str, Any]:
158
  latents = callback_kwargs["latents"]
159
  image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0] # type: ignore[attr-defined]
160
  image = pipe.image_processor.postprocess(image, output_type="np").squeeze() # type: ignore[attr-defined]
161
 
162
+ log_queue.put(("mvs", rr.Image(image)))
163
+ log_queue.put(("latents", rr.Tensor(latents.squeeze())))
164
 
165
  return callback_kwargs
166
 
167
+ def generate_mvs(log_queue, input_image, sample_steps, sample_seed):
168
 
169
  seed_everything(sample_seed)
170
 
171
+ return pipeline(
172
+ input_image,
173
+ num_inference_steps=sample_steps,
174
+ callback_on_step_end=lambda *args, **kwargs: pipeline_callback(log_queue, *args, **kwargs),
175
+ ).images[0]
176
+
177
+ # def thread_target(output_queue, input_image, sample_steps):
178
+ # z123_image = pipeline(
179
+ # input_image,
180
+ # num_inference_steps=sample_steps,
181
+ # callback_on_step_end=lambda *args, **kwargs: pipeline_callback(output_queue, *args, **kwargs),
182
+ # ).images[0]
183
+ # log_queue.put(("z123_image", z123_image))
184
+
185
+
186
+ # output_queue = SimpleQueue()
187
+ # z123_thread = threading.Thread(
188
+ # target=thread_target,
189
+ # args=
190
+ # [
191
+ # output_queue,
192
+ # input_image,
193
+ # sample_steps,
194
+ # ]
195
+ # )
196
+ # z123_thread.start()
197
+
198
+ # while True:
199
+ # msg = output_queue.get()
200
+ # yield msg
201
+ # if msg[0] == "z123_image":
202
+ # break
203
+ # z123_thread.join()
204
+
205
+ # def make3d(images: Image.Image):
206
+ # output_queue = SimpleQueue()
207
+ # handle = threading.Thread(target=_make3d, args=[output_queue, images])
208
+ # handle.start()
209
+ # while True:
210
+ # msg = output_queue.get()
211
+ # yield msg
212
+ # if msg[0] == "mesh":
213
+ # break
214
+ # handle.join()
215
+
216
+ def make3d(log_queue, images: Image.Image):
217
  global model
218
  if IS_FLEXICUBES:
219
  model.init_flexicubes_geometry(device, use_renderer=False)
 
243
 
244
  vertices, faces, vertex_colors = mesh_out
245
 
246
+ log_queue.put(
247
  (
 
248
  "mesh",
249
  rr.Mesh3D(
250
  vertex_positions=vertices,
 
253
  ),
254
  )
255
  )
256
+
257
+ return mesh_out
258
 
259
  def generate_blueprint() -> rrb.Blueprint:
260
  return rrb.Blueprint(
 
263
  rrb.Grid(
264
  rrb.Spatial2DView(origin="z123image"),
265
  rrb.Spatial2DView(origin="preprocessed_image"),
266
+ rrb.Spatial2DView(origin="mvs"),
267
+ rrb.TensorView(origin="latents", ),
268
  ),
269
  column_shares=[1, 1],
270
  ),
271
+
272
  collapse_panels=True,
273
  )
274
 
275
+ def compute(log_queue, input_image, do_remove_background, sample_steps, sample_seed):
276
+
277
+ preprocessed_image = preprocess(input_image, do_remove_background)
278
+
279
+ log_queue.put(("preprocessed_image", rr.Image(preprocessed_image)))
280
+ # rr.log("preprocessed_image", rr.Image(preprocessed_image))
281
+
282
+ z123_image = generate_mvs(log_queue, preprocessed_image, sample_steps, sample_seed)
283
+
284
+ log_queue.put(("z123image", rr.Image(z123_image)))
285
+ # rr.log("z123image", rr.Image(z123_image))
286
+
287
+ mesh_out = make3d(log_queue, z123_image)
288
+
289
+ log_queue.put("done")
290
+
291
+
292
  @spaces.GPU
293
  @rr.thread_local_stream("InstantMesh")
294
  def log_to_rr(input_image, do_remove_background, sample_steps, sample_seed):
295
 
296
+ log_queue = SimpleQueue()
297
+
298
  stream = rr.binary_stream()
299
 
300
  blueprint = generate_blueprint()
301
  rr.send_blueprint(blueprint)
302
  yield stream.read()
303
 
304
+ handle = threading.Thread(target=compute, args=[log_queue, input_image, do_remove_background, sample_steps, sample_seed])
305
+ handle.start()
306
+ while True:
307
+ msg = log_queue.get()
308
+ if msg == "done":
 
 
 
309
  break
310
+ else:
311
+ entity_path, entity = msg
 
312
  rr.log(entity_path, entity)
313
  yield stream.read()
314
+ handle.join()
315
+
 
 
 
 
 
 
 
 
 
316
  # return mesh
317
 
318
  _HEADER_ = '''