bofenghuang committed
Commit 638230e · Parent: 78c1dc5
Files changed (1)
  1. run_demo_multi_models.py +17 -2
run_demo_multi_models.py CHANGED
@@ -28,10 +28,24 @@ logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)

 device = 0 if torch.cuda.is_available() else "cpu"
-logger.info(f"Model will be loaded on device {device}")
+logger.info(f"Model will be loaded on device `{device}`")

 cached_models = {}

+
+def print_cuda_memory_info():
+    free_mem, tot_mem = torch.cuda.mem_get_info()
+    logger.info(f"CUDA memory info - free: {free_mem / 1024 ** 3:.2f} GiB, used: {(tot_mem - free_mem) / 1024 ** 3:.2f} GiB, total: {tot_mem / 1024 ** 3:.2f} GiB")
+
+
+def print_memory_info():
+    # TODO: report CPU memory usage as well
+    if device == "cpu":
+        pass
+    else:
+        print_cuda_memory_info()
+
+
 def maybe_load_cached_pipeline(model_name):
     pipe = cached_models.get(model_name)
     if pipe is None:
@@ -49,6 +63,7 @@ def maybe_load_cached_pipeline(model_name):
         pipe.model.config.max_length = MAX_NEW_TOKENS + 1

         logger.info(f"`{model_name}` pipeline has been initialized")
+        print_memory_info()

         cached_models[model_name] = pipe
     return pipe
@@ -70,7 +85,7 @@ def transcribe(microphone, file_upload, model_name):
     pipe = maybe_load_cached_pipeline(model_name)
     text = pipe(file)["text"]

-    logger.info(f"Transcription: {text}")
+    logger.info(f"Transcription by `{model_name}`: {text}")

     return warn_output + text

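The new memory logging is built on torch.cuda.mem_get_info(), which returns a (free, total) tuple in bytes for the current CUDA device; used memory is not reported directly and has to be derived by subtraction. A minimal standalone sketch of that arithmetic, using the same 1024 ** 3 conversion as the commit (variable names are illustrative):

import torch

if torch.cuda.is_available():
    # mem_get_info() reports (free_bytes, total_bytes) for the current device.
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    used_bytes = total_bytes - free_bytes
    print(f"free: {free_bytes / 1024 ** 3:.2f} GiB, "
          f"used: {used_bytes / 1024 ** 3:.2f} GiB, "
          f"total: {total_bytes / 1024 ** 3:.2f} GiB")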
 
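print_memory_info() leaves its CPU branch as a TODO. One way it could be filled in, sketched here on the assumption that psutil (not currently imported by the demo) is an acceptable extra dependency:

import psutil

def print_cpu_memory_info():
    # virtual_memory() reports system RAM figures in bytes.
    mem = psutil.virtual_memory()
    print(f"CPU memory info - free: {mem.available / 1024 ** 3:.2f} GiB, "
          f"used: {mem.used / 1024 ** 3:.2f} GiB, "
          f"total: {mem.total / 1024 ** 3:.2f} GiB")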
 
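maybe_load_cached_pipeline() itself is a plain memoization pattern: each model is loaded on first request, stored in the module-level cached_models dict, and reused on every later call, so only the first transcription with a given model pays the load cost. A self-contained sketch of the pattern; the task string, the model name, and the audio path are assumptions for illustration, not taken from the commit:

import torch
from transformers import pipeline

device = 0 if torch.cuda.is_available() else "cpu"
cached_models = {}

def maybe_load_cached_pipeline(model_name):
    pipe = cached_models.get(model_name)
    if pipe is None:
        # First request for this model: load once, then cache.
        pipe = pipeline("automatic-speech-recognition", model=model_name, device=device)
        cached_models[model_name] = pipe
    return pipe

# Repeated calls with the same name return the already-loaded pipeline.
pipe = maybe_load_cached_pipeline("openai/whisper-small")
text = pipe("sample.wav")["text"]

Because the cache lives at module scope, it persists across transcribe() calls and grows with each distinct model loaded, which is exactly what the print_memory_info() call added in this commit helps keep an eye on.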