cbensimon HF staff committed on
Commit
53afe62
1 Parent(s): ab39f0c

Update spaces/zero/torch.py

Browse files
Files changed (1) hide show
  1. spaces/zero/torch.py +14 -0
spaces/zero/torch.py CHANGED
@@ -6,6 +6,7 @@ from __future__ import annotations
6
 
7
  import multiprocessing
8
  import os
 
9
  from concurrent.futures import ProcessPoolExecutor
10
  from contextlib import suppress
11
  from functools import partial
@@ -241,8 +242,12 @@ if (torch := maybe_import_torch()):
241
  bitsandbytes.unpatch()
242
 
243
  def _move(nvidia_uuid: str):
 
244
  os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
245
  torch.Tensor([0]).cuda() # CUDA init
 
 
 
246
  for op in to_ops.items():
247
  tensor, parsed_args = op
248
  _, dtype, _, memory_format = parsed_args
@@ -251,8 +256,17 @@ if (torch := maybe_import_torch()):
251
  dtype=dtype,
252
  memory_format=memory_format,
253
  ) # type: ignore
 
 
 
254
  bitsandbytes.move()
 
 
 
255
  torch.cuda.synchronize()
 
 
 
256
 
257
  def _is_in_bad_fork():
258
  with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
 
6
 
7
  import multiprocessing
8
  import os
9
+ import time
10
  from concurrent.futures import ProcessPoolExecutor
11
  from contextlib import suppress
12
  from functools import partial
 
242
  bitsandbytes.unpatch()
243
 
244
  def _move(nvidia_uuid: str):
245
+ t0 = time.perf_counter()
246
  os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
247
  torch.Tensor([0]).cuda() # CUDA init
248
+ t1 = time.perf_counter()
249
+ print("CUDA init", t1 - t0)
250
+ t0 = t1
251
  for op in to_ops.items():
252
  tensor, parsed_args = op
253
  _, dtype, _, memory_format = parsed_args
 
256
  dtype=dtype,
257
  memory_format=memory_format,
258
  ) # type: ignore
259
+ t1 = time.perf_counter()
260
+ print("CUDA move", t1 - t0)
261
+ t0 = t1
262
  bitsandbytes.move()
263
+ t1 = time.perf_counter()
264
+ print("BNB move", t1 - t0)
265
+ t0 = t1
266
  torch.cuda.synchronize()
267
+ t1 = time.perf_counter()
268
+ print("CUDA synchronize", t1 - t0)
269
+ t0 = t1
270
 
271
  def _is_in_bad_fork():
272
  with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e: