bofenghuang committed · Commit 638230e
1 Parent(s): 78c1dc5
up
- run_demo_multi_models.py +17 -2

run_demo_multi_models.py CHANGED
@@ -28,10 +28,24 @@ logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
 device = 0 if torch.cuda.is_available() else "cpu"
-logger.info(f"Model will be loaded on device {device}")
+logger.info(f"Model will be loaded on device `{device}`")
 
 cached_models = {}
 
+
+def print_cuda_memory_info():
+    used_mem, tot_mem = torch.cuda.mem_get_info()
+    logger.info(f"CUDA memory info - Free: {used_mem / 1024 ** 3:.2f} Gb, used: {(tot_mem - used_mem) / 1024 ** 3:.2f} Gb, total: {tot_mem / 1024 ** 3:.2f} Gb")
+
+
+def print_memory_info():
+    # todo
+    if device == "cpu":
+        pass
+    else:
+        print_cuda_memory_info()
+
+
 def maybe_load_cached_pipeline(model_name):
     pipe = cached_models.get(model_name)
     if pipe is None:
@@ -49,6 +63,7 @@ def maybe_load_cached_pipeline(model_name):
         pipe.model.config.max_length = MAX_NEW_TOKENS + 1
 
         logger.info(f"`{model_name}` pipeline has been initialized")
+        print_memory_info()
 
         cached_models[model_name] = pipe
     return pipe
@@ -70,7 +85,7 @@ def transcribe(microphone, file_upload, model_name):
     pipe = maybe_load_cached_pipeline(model_name)
     text = pipe(file)["text"]
 
-    logger.info(f"Transcription
+    logger.info(f"Transcription by `{model_name}`: {text}")
 
     return warn_output + text
 
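For context on the helper added in this commit: torch.cuda.mem_get_info() returns the free and total GPU memory of the current device in bytes, as a (free, total) tuple, so the first value unpacked in print_cuda_memory_info is actually the free memory (the log message labels it correctly as "Free", even though the variable is named used_mem). Below is a minimal standalone sketch of the same helpers with clearer variable names; it is not part of the commit itself.

import logging

import torch

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

device = 0 if torch.cuda.is_available() else "cpu"


def print_cuda_memory_info():
    # torch.cuda.mem_get_info() returns (free_bytes, total_bytes) for the
    # current CUDA device; used memory is total minus free.
    free_mem, tot_mem = torch.cuda.mem_get_info()
    logger.info(
        f"CUDA memory info - Free: {free_mem / 1024 ** 3:.2f} GiB, "
        f"used: {(tot_mem - free_mem) / 1024 ** 3:.2f} GiB, "
        f"total: {tot_mem / 1024 ** 3:.2f} GiB"
    )


def print_memory_info():
    # CPU memory reporting is left as a todo in the commit; only the CUDA
    # branch logs anything for now.
    if device == "cpu":
        pass
    else:
        print_cuda_memory_info()

Called right after a pipeline is initialized, as the commit does inside maybe_load_cached_pipeline, this gives a rough view of how much GPU memory each newly cached model consumes.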