import os
import re
from pathlib import Path
from copy import copy, deepcopy
from csv import reader
from datetime import timedelta
import logging
import openai
from tqdm import tqdm
import dict_util

# punctuation dictionary for supported languages
punctuation_dict = {
    "EN": {
        "punc_str": ". , ? ! : ; - ( ) [ ] { }",
        "comma": ", ",
        "sentence_end": [".", "!", "?", ";"]
    },
    "ES": {
        "punc_str": ". , ? ! : ; - ( ) [ ] { } ¡ ¿",
        "comma": ", ",
        "sentence_end": [".", "!", "?", ";", "¡", "¿"]
    },
    "FR": {
        "punc_str": ".,?!:;«»—",
        "comma": ", ",
        "sentence_end": [".", "!", "?", ";"]
    },
    "DE": {
        "punc_str": ".,?!:;„“–",
        "comma": ", ",
        "sentence_end": [".", "!", "?", ";"]
    },
    "RU": {
        "punc_str": ".,?!:;-«»—",
        "comma": ", ",
        "sentence_end": [".", "!", "?", ";"]
    },
    "ZH": {
        "punc_str": "。,?!:;()",
        "comma": ",",
        "sentence_end": ["。", "!", "?"]
    },
    "JA": {
        "punc_str": "。、?!:;()",
        "comma": "、",
        "sentence_end": ["。", "!", "?"]
    },
    "AR": {
        "punc_str": ".,?!:;-()[]،؛ ؟ «»",
        "comma": "، ",
        "sentence_end": [".", "!", "?", ";", "؟"]
    },
}
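# lookup example: punctuation_dict["ZH"]["comma"] is the comma used when splitting
# Chinese text, and punctuation_dict["EN"]["sentence_end"] lists the characters that
# end an English sentence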

dict_path = "./domain_dict"

class SrtSegment(object):
    def __init__(self, src_lang, tgt_lang, *args) -> None:
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        if isinstance(args[0], dict):
            segment = args[0]
            self.start = segment['start']
            self.end = segment['end']
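            # keep only the fractional part of each whisper timestamp, at 10 ms
            # resolution (hundredths of a second scaled to milliseconds), e.g. 12.5 s -> 500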
            self.start_ms = int((segment['start'] * 100) % 100 * 10)
            self.end_ms = int((segment['end'] * 100) % 100 * 10)

            if self.start_ms == self.end_ms and int(segment['start']) == int(segment['end']):  # avoid empty time stamp
                self.end_ms += 500

            self.start_time = timedelta(seconds=int(segment['start']), milliseconds=self.start_ms)
            self.end_time = timedelta(seconds=int(segment['end']), milliseconds=self.end_ms)
            if self.start_ms == 0:
                self.start_time_str = str(0) + str(self.start_time).split('.')[0] + ',000'
            else:
                self.start_time_str = str(0) + str(self.start_time).split('.')[0] + ',' + \
                                      str(self.start_time).split('.')[1][:3]
            if self.end_ms == 0:
                self.end_time_str = str(0) + str(self.end_time).split('.')[0] + ',000'
            else:
                self.end_time_str = str(0) + str(self.end_time).split('.')[0] + ',' + str(self.end_time).split('.')[1][
                                                                                      :3]
            self.source_text = segment['text'].lstrip()
            self.duration = f"{self.start_time_str} --> {self.end_time_str}"
            self.translation = ""

        elif isinstance(args[0], list):
            self.source_text = args[0][2]
            self.duration = args[0][1]
            self.start_time_str = self.duration.split(" --> ")[0]
            self.end_time_str = self.duration.split(" --> ")[1]

            # parse the time to float
            self.start_ms = int(self.start_time_str.split(',')[1]) / 10
            self.end_ms = int(self.end_time_str.split(',')[1]) / 10
            start_list = self.start_time_str.split(',')[0].split(':')
            self.start = int(start_list[0]) * 3600 + int(start_list[1]) * 60 + int(start_list[2]) + self.start_ms / 100
            end_list = self.end_time_str.split(',')[0].split(':')
            self.end = int(end_list[0]) * 3600 + int(end_list[1]) * 60 + int(end_list[2]) + self.end_ms / 100
            if len(args[0]) < 5:
                self.translation = ""
            else:
                self.translation = args[0][3]
            

    def merge_seg(self, seg):
        """
        Merge the segment seg with the current segment in place.
        :param seg: Another segment that is strictly next to current one.
        :return: None
        """
        # assert seg.start_ms == self.end_ms, f"cannot merge discontinuous segments."
        self.source_text += f' {seg.source_text}'
        self.translation += f' {seg.translation}'
        self.end_time_str = seg.end_time_str
        self.end = seg.end
        self.end_ms = seg.end_ms
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def __add__(self, other):
        """
        Merge the segment seg with the current segment, and return the new constructed segment.
        No in-place modification.
        This is used for '+' operator.
        :param other: Another segment that is strictly next to added segment.
        :return: new segment of the two sub-segments
        """
        result = deepcopy(self)
        result.merge_seg(other)
        return result

    def remove_trans_punc(self) -> None:
        """
        remove punctuations in translation text
        :return: None
        """
        punc_str = punctuation_dict[self.tgt_lang]["punc_str"]
        for punc in punc_str:
            self.translation = self.translation.replace(punc, ' ')
        # translator = str.maketrans(punc, ' ' * len(punc))
        # self.translation = self.translation.translate(translator)

    def __str__(self) -> str:
        return f'{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        return f'{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'


class SrtScript(object):
    def __init__(self, src_lang, tgt_lang, segments, domain="General") -> None:
        self.domain = domain
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.segments = [SrtSegment(self.src_lang, self.tgt_lang, seg) for seg in segments]

        if self.domain != "General":
            if os.path.exists(f"{dict_path}/{self.domain}"):
                # load the domain-specific terminology dictionary
                self.dict = dict_util.term_dict(f"{dict_path}/{self.domain}", src_lang, tgt_lang)
            else:
                logging.error(f"domain {self.domain} doesn't exist, fallback to general domain, this will disable correct_with_force_term and spell_check_term")
                self.domain = "General"


    @classmethod
    def parse_from_srt_file(cls, src_lang, tgt_lang, path: str):
        with open(path, 'r', encoding="utf-8") as f:
            script_lines = [line.rstrip() for line in f.readlines()]
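        # a monolingual SRT block is 4 lines (index, duration, source, blank line);
        # a bilingual block has an extra translation line, so inspect the 4th line
        # of the first block to decide which layout the file uses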
        bilingual = False
        if script_lines[2] != '' and script_lines[3] != '':
            bilingual = True
        segments = []
        if bilingual:
            for i in range(0, len(script_lines), 5):
                segments.append(list(script_lines[i:i + 5]))
        else:
            for i in range(0, len(script_lines), 4):
                segments.append(list(script_lines[i:i + 4]))

        return cls(src_lang, tgt_lang, segments)

    def merge_segs(self, idx_list) -> SrtSegment:
        """
        Merge entire segment list to a single segment
        :param idx_list: List of index to merge
        :return: Merged list
        """
        if not idx_list:
            raise ValueError('Empty idx_list')
        seg_result = deepcopy(self.segments[idx_list[0]])
        if len(idx_list) == 1:
            return seg_result

        for idx in range(1, len(idx_list)):
            seg_result += self.segments[idx_list[idx]]

        return seg_result

    def form_whole_sentence(self):
        """
        Concatenate or Strip sentences and reconstruct segments list. This is because of
        improper segmentation from openai-whisper.
        :return: None
        """
        logging.info("Forming whole sentences...")
        merge_list = []  # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
        sentence = []
        ending_puncs = punctuation_dict[self.src_lang]["sentence_end"]
        # Get each entire sentence of distinct segments, fill indices to merge_list
        for i, seg in enumerate(self.segments):
            if seg.source_text[-1] in ending_puncs and len(seg.source_text) > 10 and 'vs.' not in seg.source_text:
                sentence.append(i)
                merge_list.append(sentence)
                sentence = []
            else:
                sentence.append(i)

        # Reconstruct segments, each with an entire sentence
        segments = []
        for idx_list in merge_list:
            if len(idx_list) > 1:
                logging.info("merging segments: %s", idx_list)
            segments.append(self.merge_segs(idx_list))

        self.segments = segments

    def remove_trans_punctuation(self):
        """
        Post-process: remove all punc after translation and split
        :return: None
        """
        for i, seg in enumerate(self.segments):
            seg.remove_trans_punc()
        logging.info("Removed punctuation in translation.")

    def set_translation(self, translate: str, id_range: tuple, model, video_name, video_link=None):
        start_seg_id = id_range[0]
        end_seg_id = id_range[1]

        src_text = ""
        for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
            src_text += seg.source_text
            src_text += '\n\n'

        def inner_func(target, input_str):
            # handling merge sentences issue.
            response = openai.ChatCompletion.create(
                model="gpt-4",
                messages=[
                    {"role": "system",
                     "content": "Your task is to merge or split sentences into a specified number of lines as required. You need to ensure the meaning of the sentences as much as possible, but when necessary, a sentence can be divided into two lines for output"},
                    {"role": "system", "content": "Note: You only need to output the processed {} sentences. If you need to output a sequence number, please separate it with a colon.".format(self.tgt_lang)},
                    {"role": "user", "content": 'Please split or combine the following sentences into {} sentences:\n{}'.format(target, input_str)}
                ],
                temperature=0.15
            )
            return response['choices'][0]['message']['content'].strip()

        # handling merge sentences issue.
        lines = translate.split('\n\n')
        if len(lines) < (end_seg_id - start_seg_id + 1):
            count = 0
            solved = True
            while count < 5 and len(lines) != (end_seg_id - start_seg_id + 1):
                count += 1
                print("Solving Unmatched Lines|iteration {}".format(count))
                logging.error("Solving Unmatched Lines|iteration {}".format(count))

                flag = True
                while flag:
                    flag = False
                    try:
                        translate = inner_func(end_seg_id - start_seg_id + 1, translate)
                    except Exception as e:
                        print("An error has occurred during solving unmatched lines:", e)
                        print("Retrying...")
                        logging.error("An error has occurred during solving unmatched lines:", e)
                        logging.error("Retrying...")
                        flag = True
                lines = translate.split('\n')

            if len(lines) < (end_seg_id - start_seg_id + 1):
                solved = False
                print("Failed Solving unmatched lines, Manually parse needed")
                logging.error("Failed Solving unmatched lines, Manually parse needed")

            # FIXME: put the error log in our log file
            if not os.path.exists("./logs"):
                os.mkdir("./logs")
            if video_link:
                log_file = "./logs/log_link.csv"
                log_exist = os.path.exists(log_file)
                with open(log_file, "a") as log:
                    if not log_exist:
                        log.write("range_of_text,iterations_solving,solved,file_length,video_link" + "\n")
                    log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
                        len(self.segments)) + ',' + video_link + "\n")
            else:
                log_file = "./logs/log_name.csv"
                log_exist = os.path.exists(log_file)
                with open(log_file, "a") as log:
                    if not log_exist:
                        log.write("range_of_text,iterations_solving,solved,file_length,video_name" + "\n")
                    log.write(str(id_range) + ',' + str(count) + ',' + str(solved) + ',' + str(
                        len(self.segments)) + ',' + video_name + "\n")
            # print(lines)

        for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
            # naive way to deal with the merged-translation problem
            # TODO: need a smarter solution

            if i < len(lines):
                if "Note:" in lines[i]:  # drop explanatory notes emitted by the model
                    lines.remove(lines[i])
                    if i == len(lines) - 1:
                        break
                if lines[i] and lines[i][0] in [' ', '\n']:
                    lines[i] = lines[i][1:]
                seg.translation = lines[i]

    def split_seg(self, seg, text_threshold, time_threshold):
        # recursively split seg at comma (or space) boundaries and return the
        # resulting list of segments; the caller rebuilds self.segments
        # ignore the initial comma to avoid infinite recursion
        src_comma_str = punctuation_dict[self.src_lang]["comma"]
        tgt_comma_str = punctuation_dict[self.tgt_lang]["comma"]

        if seg.source_text.startswith(src_comma_str):
            seg.source_text = seg.source_text[len(src_comma_str):]
        if seg.translation.startswith(tgt_comma_str):
            seg.translation = seg.translation[len(tgt_comma_str):]

        source_text = seg.source_text
        translation = seg.translation

        # split the text based on commas
        src_commas = [m.start() for m in re.finditer(src_comma_str, source_text)]
        trans_commas = [m.start() for m in re.finditer(tgt_comma_str, translation)]
        if len(src_commas) != 0:
            src_split_idx = src_commas[len(src_commas) // 2] if len(src_commas) % 2 == 1 else src_commas[
                len(src_commas) // 2 - 1]
        else:
            # split the text based on spaces
            src_space = [m.start() for m in re.finditer(' ', source_text)]
            if len(src_space) > 0:
                src_split_idx = src_space[len(src_space) // 2] if len(src_space) % 2 == 1 else src_space[
                    len(src_space) // 2 - 1]
            else:
                src_split_idx = 0

        if len(trans_commas) != 0:
            trans_split_idx = trans_commas[len(trans_commas) // 2] if len(trans_commas) % 2 == 1 else trans_commas[
                len(trans_commas) // 2 - 1]
        else:
            trans_split_idx = len(translation) // 2

            # avoid splitting in the middle of an English word
            for i in range(trans_split_idx, len(translation)):
                if not translation[i].encode('utf-8').isalpha():
                    trans_split_idx = i
                    break

        # split the time duration based on text length
        time_split_ratio = trans_split_idx / (len(seg.translation) - 1)

        src_seg1 = source_text[:src_split_idx]
        src_seg2 = source_text[src_split_idx:]
        trans_seg1 = translation[:trans_split_idx]
        trans_seg2 = translation[trans_split_idx:]

        start_seg1 = seg.start
        end_seg1 = start_seg2 = seg.start + (seg.end - seg.start) * time_split_ratio
        end_seg2 = seg.end

        seg1_dict = {}
        seg1_dict['text'] = src_seg1
        seg1_dict['start'] = start_seg1
        seg1_dict['end'] = end_seg1
        seg1 = SrtSegment(self.src_lang, self.tgt_lang, seg1_dict)
        seg1.translation = trans_seg1

        seg2_dict = {}
        seg2_dict['text'] = src_seg2
        seg2_dict['start'] = start_seg2
        seg2_dict['end'] = end_seg2
        seg2 = SrtSegment(self.src_lang, self.tgt_lang, seg2_dict)
        seg2.translation = trans_seg2

        result_list = []
        if len(seg1.translation) > text_threshold and (seg1.end - seg1.start) > time_threshold:
            result_list += self.split_seg(seg1, text_threshold, time_threshold)
        else:
            result_list.append(seg1)

        if len(seg2.translation) > text_threshold and (seg2.end - seg2.start) > time_threshold:
            result_list += self.split_seg(seg2, text_threshold, time_threshold)
        else:
            result_list.append(seg2)

        return result_list

    def check_len_and_split(self, text_threshold=30, time_threshold=1.0):
        # if translation length > text_threshold and duration > time_threshold, split the segment
        logging.info("performing check_len_and_split")
        segments = []
        for i, seg in enumerate(self.segments):
            if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
                seg_list = self.split_seg(seg, text_threshold, time_threshold)
                logging.info("splitting segment {} in to {} parts".format(i + 1, len(seg_list)))
                segments += seg_list
            else:
                segments.append(seg)

        self.segments = segments
        logging.info("check_len_and_split finished")

    def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
        # DEPRECATED
        # if translation length > text_threshold and duration > time_threshold, split the segment
        start_seg_id = range[0]
        end_seg_id = range[1]
        extra_len = 0
        segments = []
        for i, seg in enumerate(self.segments[start_seg_id - 1:end_seg_id]):
            if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
                seg_list = self.split_seg(seg, text_threshold, time_threshold)
                segments += seg_list
                extra_len += len(seg_list) - 1
            else:
                segments.append(seg)

        self.segments[start_seg_id - 1:end_seg_id] = segments
        return extra_len

    def correct_with_force_term(self):
        ## force term correction
        logging.info("performing force term correction")

        # check domain
        if self.domain == "General":
            logging.info("General domain could not perform correct_with_force_term. skip this step.")
            pass
        else:
            keywords = list(self.dict.keys())
            keywords.sort(key=lambda x: len(x), reverse=True)

            for word in keywords:
                for i, seg in enumerate(self.segments):
                    if word in seg.source_text.lower():
                        seg.source_text = re.sub(fr"({word}es|{word}s?)\b", "{}".format(self.dict.get(word)),
                                                seg.source_text, flags=re.IGNORECASE)
                        logging.info(
                            "replace term: " + word + " --> " + self.dict.get(word) + " in time stamp {}".format(
                                i + 1))
                        logging.info("source text becomes: " + seg.source_text)

    # lazily loaded frequency dictionary shared by fetchfunc
    comp_dict = {}

    def fetchfunc(self, word, threshold):
        import enchant
        result = word
        distance = 0
        threshold = threshold * len(word)
        if len(self.comp_dict) == 0:
            with open("./finetune_data/dict_freq.txt", 'r', encoding='utf-8') as f:
                self.comp_dict = {rows[0]: 1 for rows in reader(f)}
        temp = ""
        for matched in self.comp_dict:
            if (" " in matched and " " in word) or (" " not in matched and " " not in word):
                if enchant.utils.levenshtein(word, matched) < enchant.utils.levenshtein(word, temp):
                    temp = matched
        if enchant.utils.levenshtein(word, temp) < threshold:
            distance = enchant.utils.levenshtein(word, temp)
            result = temp
        return distance, result

    def extract_words(self, sentence, n):
        # split the sentence into chunks of at most n consecutive words
        # e.g. sentence: "this, is a sentence", n = 2
        # result: [["this,", "is"], ["is", "a"], ["a", "sentence"], ["this,"], ["is"], ["a"], ["sentence"]]
        words = sentence.split()
        res = []
        for j in range(n, 0, -1):
            res += [words[i:i + j] for i in range(len(words) - j + 1)]
        return res

    def spell_check_term(self):
        logging.info("performing spell check")

        # check domain
        if self.domain == "General":
            logging.info("General domain could not perform spell_check_term. skip this step.")
            pass

        import enchant
        en_dict = enchant.Dict('en_US')
        term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')

        for seg in tqdm(self.segments):
            ready_words = self.extract_words(seg.source_text, 2)
            for i in range(len(ready_words)):
                word_list = ready_words[i]
                word, real_word, pos = self.get_real_word(word_list)
                if not en_dict.check(real_word) and not term_spellDict.check(real_word):
                    distance, correct_term = self.fetchfunc(real_word, 0.3)
                    if distance != 0:
                        seg.source_text = re.sub(re.escape(word[:pos]), correct_term, seg.source_text, flags=re.IGNORECASE)
                        logging.info(
                            "replace: " + word[:pos] + " to " + correct_term + "\t distance = " + str(distance))

    def get_real_word(self, word_list: list):
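        # join the word list back into raw text and strip a trailing punctuation mark;
        # e.g. get_real_word(["this,", "is."]) -> ("this, is.", "this, is", 8)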
        word = ""
        for w in word_list:
            word += f"{w} "
        word = word[:-1]  # "this, is"
        if word[-2:] == ".\n":
            real_word = word[:-2].lower()
            n = -2
        elif word[-1:] in [".", "\n", ",", "!", "?"]:
            real_word = word[:-1].lower()
            n = -1
        else:
            real_word = word.lower()
            n = 0
        return word, real_word, len(word) + n

    ## WRITE AND READ FUNCTIONS ##

    def get_source_only(self):
        # return a string with pure source text
        result = ""
        for i, seg in enumerate(self.segments):
            result += f'{seg.source_text}\n\n\n'  # f'SENTENCE {i+1}: {seg.source_text}\n\n\n'

        return result

    def reform_src_str(self):
        result = ""
        for i, seg in enumerate(self.segments):
            result += f'{i + 1}\n'
            result += str(seg)
        return result

    def reform_trans_str(self):
        result = ""
        for i, seg in enumerate(self.segments):
            result += f'{i + 1}\n'
            result += seg.get_trans_str()
        return result

    def form_bilingual_str(self):
        result = ""
        for i, seg in enumerate(self.segments):
            result += f'{i + 1}\n'
            result += seg.get_bilingual_str()
        return result

    def write_srt_file_src(self, path: str):
        # write srt file to path
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_src_str())
        pass

    def write_srt_file_translate(self, path: str):
        logging.info("writing to " + path)
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_trans_str())
        pass

    def write_srt_file_bilingual(self, path: str):
        logging.info("writing to " + path)
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.form_bilingual_str())
        pass

    def realtime_write_srt(self, path, range, length, idx):
        # DEPRECATED
        start_seg_id = range[0]
        end_seg_id = range[1]
        with open(path, "a", encoding='utf-8') as f:
            # for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id+length]):
            #     f.write(f'{i+idx}\n')
            #     f.write(seg.get_trans_str())
            for i, seg in enumerate(self.segments):
                if i < range[0] - 1: continue
                if i >= range[1] + length: break
                f.write(f'{i + idx}\n')
                f.write(seg.get_trans_str())
        pass

    def realtime_bilingual_write_srt(self, path, range, length, idx):
        # DEPRECATED
        start_seg_id = range[0]
        end_seg_id = range[1]
        with open(path, "a", encoding='utf-8') as f:
            for i, seg in enumerate(self.segments):
                if i < range[0] - 1: continue
                if i >= range[1] + length: break
                f.write(f'{i + idx}\n')
                f.write(seg.get_bilingual_str())
        pass

def split_script(script_in, chunk_size=1000):
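    # split the full source text (segments separated by blank lines) into chunks of at
    # most chunk_size characters; returns the chunk strings together with the matching
    # (start, end) 1-based segment ranges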
    script_split = script_in.split('\n\n')
    script_arr = []
    range_arr = []
    start = 1
    end = 0
    script = ""
    for sentence in script_split:
        if len(script) + len(sentence) + 1 <= chunk_size:
            script += sentence + '\n\n'
            end += 1
        else:
            range_arr.append((start, end))
            start = end + 1
            end += 1
            script_arr.append(script.strip())
            script = sentence + '\n\n'
    if script.strip():
        script_arr.append(script.strip())
        range_arr.append((start, len(script_split) - 1))

    assert len(script_arr) == len(range_arr)
    return script_arr, range_arr
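

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library itself. It assumes an English SRT
    # file exists at the hypothetical path below; it only exercises the parsing,
    # sentence reconstruction, and writing helpers defined in this module.
    demo_path = "./test.srt"  # hypothetical input file
    if os.path.exists(demo_path):
        script = SrtScript.parse_from_srt_file("EN", "ZH", demo_path)
        script.form_whole_sentence()
        print(script.get_source_only())
        script.write_srt_file_src("./test_reformed.srt")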