import copy
import librosa
import logging
import numpy as np
import moviepy.editor as mpy
# from modelscope.pipelines import pipeline
# from modelscope.utils.constant import Tasks
from subtitle_utils import generate_srt, generate_srt_clip, distribute_spk
from trans_utils import pre_proc, proc, proc_spk, generate_vad_data

from moviepy.editor import TextClip, CompositeVideoClip, concatenate_videoclips
from moviepy.video.tools.subtitles import SubtitlesClip


class VideoClipper:
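    """Clip audio/video segments by matching target text (or speakers) against ASR output.

    asr_pipeline is expected to return per-sentence timestamps along with the
    recognized text; sd_pipeline (optional) performs speaker diarization.
    """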
    def __init__(self, asr_pipeline, sd_pipeline=None):
        logging.warning("Initializing VideoClipper.")
        self.asr_pipeline = asr_pipeline
        self.sd_pipeline = sd_pipeline

    def recog(self, audio_input, sd_switch='no', state=None):
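        """Run ASR (and optionally speaker diarization) on a 16 kHz waveform.

        audio_input: (sample_rate, np.ndarray) tuple; sample_rate must be 16000.
        sd_switch: 'yes' enables speaker diarization via self.sd_pipeline.
        Returns (recognized text, SRT string, updated state dict).
        """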
        if state is None:
            state = {}
        sr, data = audio_input
        assert sr == 16000, "16kHz sample rate required, {} given.".format(sr)
        if len(data.shape) == 2:  # multi-channel wav input
            logging.warning("Input wav shape: {}, only first channel reserved.").format(data.shape)
            data = data[:,0]
        state['audio_input'] = (sr, data)
        data = data.astype(np.float64)
        rec_result = self.asr_pipeline(audio_in=data)
        if sd_switch == 'yes':
            vad_data = generate_vad_data(data.astype(np.float32), rec_result['sentences'], sr)
            sd_result = self.sd_pipeline(audio=vad_data, batch_size=1)
            rec_result['sd_sentences'] = distribute_spk(rec_result['sentences'], sd_result['text'])
            res_srt = generate_srt(rec_result['sd_sentences'])
            state['sd_sentences'] = rec_result['sd_sentences']
        else:
            res_srt = generate_srt(rec_result['sentences'])
        state['recog_res_raw'] = rec_result['text_postprocessed']
        state['timestamp'] = rec_result['time_stamp']
        state['sentences'] = rec_result['sentences']
        res_text = rec_result['text']
        return res_text, res_srt, state

    def clip(self, dest_text, start_ost, end_ost, state, dest_spk=None):
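        """Extract the audio spans matching dest_text (or dest_spk) from state.

        dest_text / dest_spk: '#'-separated target phrases / speaker labels.
        start_ost / end_ost: start/end offsets in milliseconds applied to each span.
        Returns ((sr, audio), message, SRT string for the clipped spans).
        """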
        # get from state
        audio_input = state['audio_input']
        recog_res_raw = state['recog_res_raw']
        timestamp = state['timestamp']
        sentences = state['sentences']
        sr, data = audio_input
        data = data.astype(np.float64)

        all_ts = []
        if dest_spk is None or dest_spk == '' or 'sd_sentences' not in state:
            for _dest_text in dest_text.split('#'):
                _dest_text = pre_proc(_dest_text)
                all_ts.extend(proc(recog_res_raw, timestamp, _dest_text))
        else:
            for _dest_spk in dest_spk.split('#'):
                all_ts.extend(proc_spk(_dest_spk, state['sd_sentences']))
        ts = all_ts
        ts.sort()
        srt_index = 0
        clip_srt = ""
        if len(ts):
            start, end = ts[0]
            start = min(max(0, start + start_ost * 16), len(data))
            end = min(max(0, end + end_ost * 16), len(data))
            res_audio = data[start:end]
            start_end_info = "from {} to {}".format(start / 16000, end / 16000)
            srt_clip, _, srt_index = generate_srt_clip(sentences, start / 16000.0, end / 16000.0, begin_index=srt_index)
            clip_srt += srt_clip
            for _ts in ts[1:]:  # multiple sentence input or multiple output matched
                start, end = _ts
                start = min(max(0, start + start_ost * 16), len(data))
                end = min(max(0, end + end_ost * 16), len(data))
                start_end_info += ", from {} to {}".format(start / 16000, end / 16000)
                res_audio = np.concatenate([res_audio, data[start:end]], -1)
                srt_clip, _, srt_index = generate_srt_clip(sentences, start / 16000.0, end / 16000.0, begin_index=srt_index - 1)
                clip_srt += srt_clip
        if len(ts):
            message = "{} periods found in the speech: ".format(len(ts)) + start_end_info
        else:
            message = "No period found in the speech, return raw speech. You may check the recognition result and try other destination text."
            res_audio = data
        return (sr, res_audio), message, clip_srt

    def video_recog(self, video_filename, sd_switch='no'):
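        """Extract 16 kHz audio from a video file and run recog() on it.

        Writes a temporary .wav next to the video; returns recog()'s outputs,
        with the video object and file paths stored in state for later clipping.
        """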
        clip_video_file = video_filename[:-4] + '_clip.mp4'
        video = mpy.VideoFileClip(video_filename)
        audio_file = video_filename[:-3] + 'wav'
        video.audio.write_audiofile(audio_file)
        wav = librosa.load(audio_file, sr=16000)[0]
        state = {
            'video_filename': video_filename,
            'clip_video_file': clip_video_file,
            'video': video,
        }
        return self.recog((16000, wav), sd_switch, state)

    def video_clip(self, dest_text, start_ost, end_ost, state, font_size=32, font_color='white', add_sub=False, dest_spk=None):
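        """Cut the video spans matching dest_text (or dest_spk) and optionally burn in subtitles.

        start_ost / end_ost are millisecond offsets; font_size / font_color style
        the optional hard subtitles. Returns (output file path, message, SRT string).
        """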
        # get from state
        recog_res_raw = state['recog_res_raw']
        timestamp = state['timestamp']
        sentences = state['sentences']
        video = state['video']
        clip_video_file = state['clip_video_file']
        video_filename = state['video_filename']
        
        all_ts = []
        srt_index = 0
        if dest_spk is None or dest_spk == '' or 'sd_sentences' not in state:
            for _dest_text in dest_text.split('#'):
                _dest_text = pre_proc(_dest_text)
                all_ts.extend(proc(recog_res_raw, timestamp, _dest_text))
        else:
            for _dest_spk in dest_spk.split('#'):
                all_ts.extend(proc_spk(_dest_spk, state['sd_sentences']))
        time_acc_ost = 0.0
        ts = all_ts
        ts.sort()
        clip_srt = ""
        if len(ts):
            # One subtitle generator is shared by all sub-clips.
            generator = lambda txt: TextClip(txt, font='./font/STHeitiMedium.ttc', fontsize=font_size, color=font_color)
            start, end = ts[0][0] / 16000, ts[0][1] / 16000
            srt_clip, subs, srt_index = generate_srt_clip(sentences, start, end, begin_index=srt_index, time_acc_ost=time_acc_ost)
            start, end = max(start + start_ost / 1000.0, 0), end + end_ost / 1000.0
            video_clip = video.subclip(start, end)
            start_end_info = "from {} to {}".format(start, end)
            clip_srt += srt_clip
            if add_sub:
                subtitles = SubtitlesClip(subs, generator)
                video_clip = CompositeVideoClip([video_clip, subtitles.set_pos(('center', 'bottom'))])
            concate_clip = [video_clip]
            time_acc_ost += end - start
            for _ts in ts[1:]:
                start, end = _ts[0] / 16000, _ts[1] / 16000
                srt_clip, subs, srt_index = generate_srt_clip(sentences, start, end, begin_index=srt_index - 1, time_acc_ost=time_acc_ost)
                start, end = max(start + start_ost / 1000.0, 0), end + end_ost / 1000.0
                _video_clip = video.subclip(start, end)
                start_end_info += ", from {} to {}".format(start, end)
                clip_srt += srt_clip
                if add_sub:
                    subtitles = SubtitlesClip(subs, generator)
                    _video_clip = CompositeVideoClip([_video_clip, subtitles.set_pos(('center', 'bottom'))])
                concate_clip.append(copy.copy(_video_clip))
                time_acc_ost += end - start
            message = "{} periods found in the audio: ".format(len(ts)) + start_end_info
            logging.warning("Concatenating...")
            if len(concate_clip) > 1:
                video_clip = concatenate_videoclips(concate_clip)
            video_clip.write_videofile(clip_video_file, audio_codec="aac")
        else:
            clip_video_file = video_filename
            message = "No period found in the audio; returning the raw video. You may check the recognition result and try other destination text."
        return clip_video_file, message, clip_srt
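

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of driving VideoClipper, assuming pipelines built with
# the commented-out modelscope imports above. The model ids below are
# hypothetical placeholders; substitute real ASR / speaker-diarization models.
#
# from modelscope.pipelines import pipeline
# from modelscope.utils.constant import Tasks
#
# asr = pipeline(task=Tasks.auto_speech_recognition, model='<asr-model-id>')
# sd = pipeline(task=Tasks.speaker_diarization, model='<sd-model-id>')
# clipper = VideoClipper(asr, sd)
# text, srt, state = clipper.video_recog('demo.mp4', sd_switch='yes')
# clip_file, msg, clip_srt = clipper.video_clip('target phrase', 0, 100, state, add_sub=True)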