File size: 8,677 Bytes
09cabee
 
 
66e606c
 
b39d769
09cabee
 
cf5f1c9
 
 
5f10ef2
 
 
 
 
 
 
 
 
 
 
 
cf5f1c9
 
 
 
 
 
6113bd9
 
cf5f1c9
 
6113bd9
 
 
 
 
 
 
cf5f1c9
6113bd9
cf5f1c9
 
6113bd9
cf5f1c9
 
6113bd9
09cabee
 
 
 
 
 
 
cf5f1c9
 
 
 
 
 
 
 
 
 
09cabee
cf5f1c9
 
6113bd9
 
 
 
 
 
 
 
 
 
 
5f10ef2
6113bd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf5f1c9
 
 
b39d769
cf5f1c9
b39d769
6113bd9
b39d769
 
6113bd9
b39d769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf5f1c9
 
fdce050
 
 
 
 
 
5f10ef2
 
b39d769
 
 
 
09cabee
6113bd9
 
 
 
 
 
 
 
09cabee
cf5f1c9
 
 
6113bd9
 
cf5f1c9
 
 
 
 
6113bd9
 
cf5f1c9
 
 
 
 
6113bd9
 
cf5f1c9
 
 
 
 
6113bd9
 
cf5f1c9
 
 
 
09cabee
cf5f1c9
 
09cabee
 
cf5f1c9
 
 
 
 
 
 
 
 
 
66e606c
 
 
 
cf5f1c9
66e606c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf5f1c9
 
 
 
09cabee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
from datetime import timedelta
import os
import whisper
from csv import reader
import re
import openai

class SRT_segment(object):
    """One subtitle cue: a time range, its source text and a translation.

    Attributes set by the constructor:
        start_time_str: cue start formatted as ``HH:MM:SS,mmm``.
        end_time_str: cue end formatted as ``HH:MM:SS,mmm``.
        duration: ``"<start_time_str> --> <end_time_str>"``.
        source_text: the cue text in the source language.
        translation: translated text, initially the empty string.
    """

    def __init__(self, *args) -> None:
        """Build a segment from the first positional argument.

        Two input shapes are supported:

        * a ``dict`` with numeric ``'start'`` / ``'end'`` values in seconds
          and a ``'text'`` entry;
        * a ``list`` or ``tuple`` whose item 1 is an
          ``"<start> --> <end>"`` string and whose item 2 is the cue text.

        Raises:
            TypeError: if no argument is given or its type is unsupported.
        """
        if not args:
            raise TypeError("SRT_segment expects a dict or a list of SRT lines")

        segment = args[0]
        if isinstance(segment, dict):
            self.start_time_str = self._format_timestamp(segment['start'])
            self.end_time_str = self._format_timestamp(segment['end'])
            self.source_text = segment['text']
            self.duration = f"{self.start_time_str} --> {self.end_time_str}"
        elif isinstance(segment, (list, tuple)):
            self.source_text = segment[2]
            self.duration = segment[1]
            bounds = self.duration.split(" --> ")
            self.start_time_str = bounds[0]
            self.end_time_str = bounds[1]
        else:
            # Fail loudly instead of returning an object with no attributes.
            raise TypeError(
                f"unsupported segment type: {type(segment).__name__}"
            )
        self.translation = ""

    @staticmethod
    def _format_timestamp(seconds) -> str:
        """Format a number of seconds as an SRT timestamp ``HH:MM:SS,mmm``.

        Works in whole milliseconds so that float representation error
        (e.g. 0.29 s) cannot shave a millisecond off, and so that values of
        ten hours or more keep a well-formed hour field.
        """
        total_ms = max(0, int(round(seconds * 1000)))
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        secs, millis = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    def merge_seg(self, seg):
        """Absorb ``seg`` into this segment, extending it to ``seg``'s end time.

        Source text and translation are concatenated as-is; no separator is
        inserted between the two pieces.
        """
        self.source_text += seg.source_text
        self.translation += seg.translation
        self.end_time_str = seg.end_time_str
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def __str__(self) -> str:
        """Return the cue body (time range + source text) followed by a blank line."""
        return f'{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        """Return the cue body using the translation instead of the source text."""
        return f'{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        """Return the cue body with the source text followed by the translation."""
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'

class SRT_script():
    """An ordered list of SRT_segment objects with helpers to merge,
    translate, correct and serialise them.

    Attributes:
        segments: list of SRT_segment objects, in script order.
    """

    def __init__(self, segments) -> None:
        """Wrap every entry of ``segments`` in an SRT_segment."""
        self.segments = [SRT_segment(seg) for seg in segments]

    @classmethod
    def parse_from_srt_file(cls, path:str):
        """Build a script from an SRT file laid out in 4-line groups.

        Each group is expected to be: index line, time-range line, text
        line, blank line. A trailing group with fewer than three lines
        (e.g. stray blank lines at end of file) is skipped instead of
        crashing the segment constructor.
        """
        with open(path, 'r', encoding="utf-8") as f:
            script_lines = f.read().splitlines()

        segments = []
        for i in range(0, len(script_lines), 4):
            chunk = script_lines[i:i+4]
            if len(chunk) >= 3:
                segments.append(chunk)

        return cls(segments)

    def merge_segs(self, idx_list) -> "SRT_segment":
        """Merge the segments at ``idx_list`` into the first one and return it.

        The first listed segment is modified in place.
        """
        final_seg = self.segments[idx_list[0]]
        for idx in idx_list[1:]:
            final_seg.merge_seg(self.segments[idx])

        return final_seg

    def form_whole_sentence(self):
        """Merge consecutive segments so each resulting segment ends a sentence.

        A sentence ends at a segment whose source text ends with '.'.
        Trailing segments without a final '.' are kept as one last group
        rather than being dropped.
        """
        merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
        sentence = []
        for i, seg in enumerate(self.segments):
            sentence.append(i)
            # endswith() is safe on empty text, unlike indexing [-1]
            if seg.source_text.endswith('.'):
                merge_list.append(sentence)
                sentence = []
        if sentence:
            # unterminated tail: keep it instead of silently losing it
            merge_list.append(sentence)

        self.segments = [self.merge_segs(idx_list) for idx_list in merge_list]

    @staticmethod
    def _split_translation(translate:str):
        """Split a translation on blank lines, dropping "(Note:" paragraphs."""
        return [line for line in translate.split('\n\n') if "(Note:" not in line]

    def set_translation(self, translate:str, id_range:tuple):
        """Assign translated lines to the segments in ``id_range``.

        Args:
            translate: translated text, one paragraph per segment, with
                paragraphs separated by a blank line.
            id_range: ``(start, end)`` segment ids, 1-based and inclusive.

        If the number of paragraphs (notes excluded) does not match the
        number of segments, the text is sent to the OpenAI chat API to be
        re-aligned, and the response is used instead. Everything before
        the first ':' of a paragraph is treated as a sentence label and
        removed; paragraphs without ':' are used verbatim. Segments beyond
        the available paragraphs keep their current translation.
        """
        start_seg_id = id_range[0]
        end_seg_id = id_range[1]
        target_segs = self.segments[start_seg_id-1:end_seg_id]

        lines = self._split_translation(translate)

        if len(lines) != (end_seg_id - start_seg_id + 1):
            input_str = "\n"
            #initialize GPT input
            for i, seg in enumerate(target_segs):
                input_str += 'Sentence %d: ' %(i+1)+ seg.source_text + '\n'
                #Append to prompt string
                #Adds sentence index let GPT keep track of sentence breaks
            input_str += translate
            #append translate to prompt
            response = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages = [
                    {"role": "system", "content": "You are a helpful assistant that help calibrates English to Chinese subtitle translations in starcraft2."},
                    {"role": "system", "content": "You are provided with a translated Chinese transcript, you need to reformat the Chinese sentence to match the meaning and sentence number as the English transcript"},
                    {"role": "system", "content": "There is no need for you to add any comments or notes, and do not modify the English transcript."},
                    {"role": "user", "content": 'Reformat the Chinese with the English transcript given: "{}"'.format(input_str)}
                ],
               temperature=0.15
            )

            # chat completions return the text under message.content
            translate = response['choices'][0]['message']['content'].strip()
            # re-split so the recalibrated text is what actually gets assigned
            lines = self._split_translation(translate)

        # naive way to deal with the merge translation problem
        # TODO: need a smarter solution
        for seg, line in zip(target_segs, lines):
            _label, sep, rest = line.partition(":")
            seg.translation = rest.strip() if sep else line

    def split_seg(self, seg_id):
        """Not implemented yet."""
        # TODO: evenly split seg to 2 parts and add new seg into self.segments
        pass

    def check_len_and_split(self, threshold):
        """Not implemented yet."""
        # TODO: if sentence length >= threshold, split this segments to two
        pass

    def get_source_only(self):
        """Return all source texts, each labelled ``SENTENCE <n>: `` (1-based)."""
        return "".join(
            f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
            for i, seg in enumerate(self.segments)
        )

    def reform_src_str(self):
        """Return the script as SRT text using the source text."""
        return "".join(
            f'{i+1}\n{seg}' for i, seg in enumerate(self.segments)
        )

    def reform_trans_str(self):
        """Return the script as SRT text using the translation."""
        return "".join(
            f'{i+1}\n{seg.get_trans_str()}' for i, seg in enumerate(self.segments)
        )

    def form_bilingual_str(self):
        """Return the script as SRT text with source text and translation."""
        return "".join(
            f'{i+1}\n{seg.get_bilingual_str()}' for i, seg in enumerate(self.segments)
        )

    def write_srt_file_src(self, path:str):
        """Write the source-language SRT text to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_src_str())

    def write_srt_file_translate(self, path:str):
        """Write the translated SRT text to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_trans_str())

    def write_srt_file_bilingual(self, path:str):
        """Write the bilingual SRT text to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.form_bilingual_str())

    def correct_with_force_term(self, dict_path:str="finetune_data/dict.csv"):
        """Replace dictionary terms in every segment's source text.

        Args:
            dict_path: CSV file whose first column is the term to look up
                (matched against the lower-cased word) and whose second
                column is the replacement. Rows with fewer than two
                columns are ignored.

        Words are the space-separated tokens of the source text; a token
        ending in ".\\n" is looked up without that suffix, which is kept.
        Text that contains no dictionary term is left exactly unchanged.
        """
        ## force term correction
        # TODO: shortcut translation i.e. VA, ob
        # TODO: variety of translation

        # load term dictionary
        with open(dict_path,'r', encoding='utf-8') as f:
            term_dict = {row[0]: row[1] for row in reader(f) if len(row) >= 2}

        # change term
        for seg in self.segments:
            # pad newlines with a space so a line break also separates words
            ready_words = re.sub('\n', '\n ', seg.source_text).split(" ")
            for i, word in enumerate(ready_words):
                if word[-2:] == ".\n":
                    stem = word[:-2]
                    if stem.lower() in term_dict:
                        ready_words[i] = term_dict[stem.lower()] + ".\n"
                elif word.lower() in term_dict:
                    ready_words[i] = term_dict[word.lower()]
            # joining with single spaces restores the original spacing
            # without leaving a trailing space behind
            seg.source_text = re.sub('\n ', '\n', " ".join(ready_words))