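"""Batch-transcribe per-speaker audio with OpenAI Whisper.

Walks the input directory from config.resample_config.in_dir (one
sub-directory per speaker), resamples each wav to the sampling rate given
by the training config, transcribes it with Whisper, caches the text in a
.lab file next to the audio, and writes "path|speaker|LANG|text" annotation
lines to config.preprocess_text_config.transcription_path.

Flags: --languages {CJE,CJ,C} selects which detected languages to keep;
--whisper_size selects the Whisper checkpoint. Requires a CUDA GPU.
"""
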
import argparse
import json
import os

import torch
import torchaudio
import whisper

from config import config
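
# Map Whisper language codes to the dataset's language tags. This module-level
# dict is the fallback (all three languages) when --languages is not one of the
# handled values; it is overridden in __main__ below.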
lang2token = {
    "zh": "ZH|",
    "ja": "JP|",
    "en": "EN|",
}


def transcribe_one(audio_path):
    """Transcribe a single audio file; returns (language_code, text).

    Relies on the module-level `model` loaded in __main__.
    """
    # load audio and pad/trim it to fit 30 seconds
    audio = whisper.load_audio(audio_path)
    audio = whisper.pad_or_trim(audio)

    # make a log-Mel spectrogram and move it to the same device as the model;
    # large-v3 checkpoints expect 128 mel bins instead of the default 80, so
    # retry with n_mels=128 if the first attempt fails
    try:
        mel = whisper.log_mel_spectrogram(audio).to(model.device)
        _, probs = model.detect_language(mel)
    except RuntimeError:
        mel = whisper.log_mel_spectrogram(audio=audio, n_mels=128).to(model.device)
        _, probs = model.detect_language(mel)

    # pick the most probable spoken language
    lang = max(probs, key=probs.get)
    print(f"Detected language: {lang}")

    # decode the audio
    options = whisper.DecodingOptions(beam_size=5)
    result = whisper.decode(model, mel, options)

    # print the recognized text
    print(result.text)
    return lang, result.text


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--languages", default="CJ",
                        help="languages to keep: CJE (Chinese/Japanese/English), CJ, or C")
    parser.add_argument("--whisper_size", default="medium",
                        help="Whisper model size to load")
    args = parser.parse_args()
    if args.languages == "CJE":
        lang2token = {
            "zh": "ZH|",
            "ja": "JP|",
            "en": "EN|",
        }
    elif args.languages == "CJ":
        lang2token = {
            "zh": "ZH|",
            "ja": "JP|",
        }
    elif args.languages == "C":
        lang2token = {
            "zh": "ZH|",
        }
    assert torch.cuda.is_available(), "Please enable GPU in order to run Whisper!"
    # cache the checkpoint under ./whisper_model instead of the default location
    model = whisper.load_model(args.whisper_size, download_root="./whisper_model")
    parent_dir = config.resample_config.in_dir
    print(parent_dir)
    speaker_names = next(os.walk(parent_dir))[1]
    speaker_annos = []
    # note: this count also includes already-processed ("processed_") files,
    # so the progress counter below is an upper bound
    total_files = sum(len(files) for _, _, files in os.walk(parent_dir))
    # resample to the sampling rate the model will be trained with
    with open(config.train_ms_config.config_path, "r", encoding="utf-8") as f:
        hps = json.load(f)
    target_sr = hps["data"]["sampling_rate"]
    processed_files = 0
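    # Main loop: resample each wav once (cached as processed_<name>.wav) and
    # transcribe it once (cached as processed_<name>.lab), so an interrupted
    # run can resume without redoing work.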
    for speaker in speaker_names:
        for wavfile in next(os.walk(os.path.join(parent_dir, speaker)))[2]:
            # skip files we already resampled on a previous run
            if wavfile.startswith("processed_"):
                continue
            try:
                save_path = os.path.join(parent_dir, speaker, f"processed_{wavfile}")
                lab_path = os.path.join(parent_dir, speaker, f"processed_{os.path.splitext(wavfile)[0]}.lab")
                wav_path = os.path.join(parent_dir, speaker, wavfile)
                if not os.path.exists(save_path):
                    # mix down to mono and resample to the target rate
                    processed = True
                    wav, sr = torchaudio.load(wav_path, frame_offset=0, num_frames=-1,
                                              normalize=True, channels_first=True)
                    wav = wav.mean(dim=0).unsqueeze(0)
                    if sr != target_sr:
                        wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(wav)
                    if wav.shape[1] / target_sr > 20:
                        print(f"warning: {wavfile} is longer than 20 seconds\n")
                    torchaudio.save(save_path, wav, target_sr, channels_first=True)
                else:
                    # resampled file already exists from a previous run
                    processed = False
                # reuse a cached transcription if a valid .lab file exists,
                # otherwise transcribe the audio with Whisper
                try:
                    with open(lab_path, "r", encoding="utf-8") as f:
                        text = f.read()
                    assert "|" in text, "cached .lab file is missing its language tag"
                    print("[resume]: " + lab_path + " found and read successfully")
                except Exception as e:
                    if not processed:
                        print("[resume]: " + lab_path + " not found or unreadable: " + str(e))
                    lang, text = transcribe_one(save_path)
                    if lang not in lang2token:
                        print(f"{lang} not supported, ignoring\n")
                        continue
                    text = lang2token[lang] + text + "\n"
                    with open(lab_path, "w", encoding="utf-8") as f:
                        f.write(text)
                speaker_annos.append(save_path + "|" + speaker + "|" + text)
                processed_files += 1
                print(f"Processed: {processed_files}/{total_files}")
            except Exception as e:
                print(e)
                continue
    if len(speaker_annos) == 0:
        print("Warning: speaker_annos is empty.")
        print("This is NOT expected. Please check your file structure, make sure your audio language is supported, and check your ffmpeg path.")
    else:
        with open(config.preprocess_text_config.transcription_path, "w", encoding="utf-8") as f:
            for line in speaker_annos:
                f.write(line)
        print("finished")