Spaces:
Sleeping
Sleeping
Yuhan-Lu
commited on
Commit
·
e90d25c
1
Parent(s):
d231d79
fix logging path; enable CUDA for whispher
Browse filesFormer-commit-id: 0ea328b0a70e95f6061b93083ed418eafa4857c8
- pipeline.py +8 -2
pipeline.py
CHANGED
@@ -9,6 +9,7 @@ import whisper
|
|
9 |
from srt2ass import srt2ass
|
10 |
import logging
|
11 |
from datetime import datetime
|
|
|
12 |
|
13 |
import subprocess
|
14 |
|
@@ -109,7 +110,10 @@ def get_srt_class(srt_file_en, result_path, video_name, audio_path, audio_file =
|
|
109 |
|
110 |
# use stable-whisper
|
111 |
elif method == "stable":
|
112 |
-
|
|
|
|
|
|
|
113 |
transcript = model.transcribe(audio_path, regroup = False, initial_prompt="Hello, welcome to my lecture. Are you good my friend?")
|
114 |
(
|
115 |
transcript
|
@@ -265,7 +269,9 @@ def main():
|
|
265 |
|
266 |
audio_path, audio_file, video_path, VIDEO_NAME = get_sources(args, DOWNLOAD_PATH, RESULT_PATH, VIDEO_NAME)
|
267 |
|
268 |
-
|
|
|
|
|
269 |
logging.info("---------------------Video Info---------------------")
|
270 |
logging.info("Video name: {}, translation model: {}, video link: {}".format(VIDEO_NAME, args.model_name, args.link))
|
271 |
|
|
|
9 |
from srt2ass import srt2ass
|
10 |
import logging
|
11 |
from datetime import datetime
|
12 |
+
import torch
|
13 |
|
14 |
import subprocess
|
15 |
|
|
|
110 |
|
111 |
# use stable-whisper
|
112 |
elif method == "stable":
|
113 |
+
|
114 |
+
# use cuda if available
|
115 |
+
devices = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
116 |
+
model = stable_whisper.load_model(whisper_model, device = devices)
|
117 |
transcript = model.transcribe(audio_path, regroup = False, initial_prompt="Hello, welcome to my lecture. Are you good my friend?")
|
118 |
(
|
119 |
transcript
|
|
|
269 |
|
270 |
audio_path, audio_file, video_path, VIDEO_NAME = get_sources(args, DOWNLOAD_PATH, RESULT_PATH, VIDEO_NAME)
|
271 |
|
272 |
+
if not os.path.exists(args.log_dir):
|
273 |
+
os.makedirs(args.log_dir)
|
274 |
+
logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler("{}/{}_{}.log".format(args.log_dir, VIDEO_NAME, datetime.now().strftime("%m%d%Y_%H%M%S")), 'w', encoding='utf-8')])
|
275 |
logging.info("---------------------Video Info---------------------")
|
276 |
logging.info("Video name: {}, translation model: {}, video link: {}".format(VIDEO_NAME, args.model_name, args.link))
|
277 |
|