metricv committed

Commit f953fd7
1 Parent(s): cece17c

Initial commit

Files changed (9):
  1. .gitignore +1 -0
  2. .gitmodules +3 -0
  3. __init__.py +0 -0
  4. data +1 -0
  5. extract_ass.py +16 -0
  6. model.py +284 -0
  7. model_consts.py +9 -0
  8. train.py +31 -0
  9. utils.py +254 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/
.gitmodules ADDED
@@ -0,0 +1,3 @@
+ [submodule "data"]
+ path = data
+ url = git@hf.co:datasets/metricv/metricsubs-segmenter
__init__.py ADDED
File without changes
data ADDED
@@ -0,0 +1 @@
+ Subproject commit f8f1b533b09e44d6b885dd9931a9a56f8f8ce319
extract_ass.py ADDED
@@ -0,0 +1,16 @@
+ import ass
+ import os
+ import sys
+
+ if __name__ == "__main__":
+     filename = sys.argv[1]
+
+     with open(filename, "r", encoding="utf-8-sig") as fin:
+         doc = ass.parse(fin)
+
+     for e in doc.events:
+         if isinstance(e, ass.Dialogue) and e.style == "英":
+             print(e.text.strip())
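extract_ass.py dumps every dialogue event styled "英" (the English line of the bilingual subtitle) to stdout, one event per line; that plain-text output is the training format utils.tag_training_data() expects. A hypothetical invocation (file names are illustrative, not from the commit):

    python extract_ass.py episode01.ass > data/episode01.txt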
model.py ADDED
@@ -0,0 +1,284 @@
+ from typing import Any
+ import torch
+ from torch import nn
+ from torch.utils.data import Dataset, DataLoader
+ import numpy as np
+ from os import listdir
+ from os.path import isfile, join
+
+ if __package__ is None or __package__ == "":
+     from utils import tag_training_data, get_upenn_tags_dict, parse_tags
+ else:
+     from .utils import tag_training_data, get_upenn_tags_dict, parse_tags
+
+ # Model Type 1: LSTM with 1-logit lookahead.
+ class SegmentorDataset(Dataset):
+     def __init__(self, datapoints):
+         self.datapoints = [(torch.from_numpy(k).float(), torch.tensor([t]).float()) for k, t in datapoints]
+
+     def __len__(self):
+         return len(self.datapoints)
+
+     def __getitem__(self, idx):
+         return self.datapoints[idx][0], self.datapoints[idx][1]
+
+ class RNN(nn.Module):
+     def __init__(self, input_size, hidden_size, num_layers, device=None):
+         super(RNN, self).__init__()
+
+         if device is None:
+             self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         else:
+             self.device = device
+
+         self.num_layers = num_layers
+         self.hidden_size = hidden_size
+         self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
+
+         self.fc = nn.Linear(hidden_size, 1)
+
+     def forward(self, x):
+         h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=self.device)
+         c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=self.device)
+         out, _ = self.rnn(x, (h0, c0))
+
+         # Keep only the last time step and project it to a single logit.
+         out = out[:, -1, :]
+         out = self.fc(out)
+
+         return out
+
+ # Model 2: Bidirectional LSTM with entire-sequence context (hopefully).
+ class SegmentorDatasetDirectTag(Dataset):
+     def __init__(self, document_root: str):
+         self.tags_dict = get_upenn_tags_dict()
+         self.datapoints = []
+         self.eye = np.eye(len(self.tags_dict))
+
+         files = listdir(document_root)
+         for f in files:
+             if f.endswith(".txt"):
+                 fname = join(document_root, f)
+                 print(f"Loaded datafile: {fname}")
+                 reconstructed_tags = tag_training_data(fname)
+                 input, tag = parse_tags(reconstructed_tags)
+                 self.datapoints.append((
+                     np.array(input),
+                     np.array(tag)
+                 ))
+
+     def __len__(self):
+         return len(self.datapoints)
+
+     def __getitem__(self, idx):
+         item = self.datapoints[idx]
+         # One-hot encode the tag IDs on the fly.
+         return torch.from_numpy(self.eye[item[0]]).float(), torch.from_numpy(item[1]).float()
+
+ # The same dataset without one-hot embedding of the input.
+ class SegmentorDatasetNonEmbed(Dataset):
+     def __init__(self, document_root: str):
+         self.datapoints = []
+
+         files = listdir(document_root)
+         for f in files:
+             if f.endswith(".txt"):
+                 fname = join(document_root, f)
+                 print(f"Loaded datafile: {fname}")
+                 reconstructed_tags = tag_training_data(fname)
+                 input, tag = parse_tags(reconstructed_tags)
+                 self.datapoints.append((
+                     np.array(input),
+                     np.array(tag)
+                 ))
+
+     def __len__(self):
+         return len(self.datapoints)
+
+     def __getitem__(self, idx):
+         item = self.datapoints[idx]
+         return torch.from_numpy(item[0]).int(), torch.from_numpy(item[1]).float()
+
+ class BidirLSTMSegmenter(nn.Module):
+     def __init__(self, input_size, hidden_size, num_layers, device=None):
+         super(BidirLSTMSegmenter, self).__init__()
+
+         if device is None:
+             self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         else:
+             self.device = device
+
+         self.num_layers = num_layers
+         self.hidden_size = hidden_size
+         self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True, device=self.device)
+
+         self.fc = nn.Linear(2 * hidden_size, 1, device=self.device)
+         self.final = nn.Sigmoid()
+
+     def forward(self, x):
+         h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size, device=self.device)
+         c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size, device=self.device)
+         out, _ = self.rnn(x, (h0, c0))
+
+         # Shape of out: [batch, seq_length, num_directions * hidden_size].
+         # Apply the linear head at every time step and drop the singleton
+         # logit dimension, leaving one score per token: [batch, seq_length].
+         out_fced = self.fc(out)[:, :, 0]
+
+         return self.final(out_fced)
+
+ class BidirLSTMSegmenterWithEmbedding(nn.Module):
+     def __init__(self, input_size, embedding_size, hidden_size, num_layers, device=None):
+         super(BidirLSTMSegmenterWithEmbedding, self).__init__()
+
+         if device is None:
+             self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         else:
+             self.device = device
+
+         self.num_layers = num_layers
+         self.hidden_size = hidden_size
+         self.embedding_size = embedding_size
+
+         self.embedding = nn.Embedding(input_size, embedding_dim=embedding_size, device=self.device)
+         self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers, batch_first=True, bidirectional=True, device=self.device)
+
+         self.fc = nn.Linear(2 * hidden_size, 1, device=self.device)
+         self.final = nn.Sigmoid()
+
+     def forward(self, x):
+         h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size, device=self.device)
+         c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size, device=self.device)
+         embedded = self.embedding(x)
+         out, _ = self.rnn(embedded, (h0, c0))
+
+         # Per-token break scores, as in BidirLSTMSegmenter: [batch, seq_length].
+         out_fced = self.fc(out)[:, :, 0]
+
+         return self.final(out_fced)
+
+ def collate_fn_padd(batch):
+     '''
+     Pads a batch of variable-length sequences.
+
+     Note: tensors are converted manually here since the ToTensor transform
+     assumes it takes in images rather than arbitrary tensors.
+     '''
+     inputs = [i[0] for i in batch]
+     tags = [i[1] for i in batch]
+
+     padded_input = torch.nn.utils.rnn.pad_sequence(inputs, batch_first=True)
+     combined_outputs = torch.nn.utils.rnn.pad_sequence(tags, batch_first=True)
+
+     return (padded_input, combined_outputs)
+
+ def get_dataloader(dataset: SegmentorDataset, batch_size):
+     return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_padd)
+
+ def train_model(model: RNN,
+                 dataset,
+                 lr=1e-3,
+                 num_epochs=3,
+                 batch_size=100,
+                 ):
+     train_loader = get_dataloader(dataset, batch_size=batch_size)
+
+     n_total_steps = len(train_loader)
+     criterion = nn.MSELoss()
+     optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+     device = model.device
+
+     for epoch in range(num_epochs):
+         for i, (input, tags) in enumerate(train_loader):
+             input = input.to(device)
+             tags = tags.to(device)
+
+             outputs = model(input)
+             loss = criterion(outputs, tags)
+
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+
+             if i % 100 == 0:
+                 print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss [{loss.item():.4f}]")
+
+ def train_bidirlstm_model(model: BidirLSTMSegmenter,
+                           dataset: SegmentorDatasetDirectTag,
+                           lr=1e-3,
+                           num_epochs=3,
+                           batch_size=1,
+                           ):
+     train_loader = get_dataloader(dataset, batch_size=batch_size)
+
+     n_total_steps = len(train_loader)
+     criterion = nn.BCELoss()
+     optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+     device = model.device
+
+     for epoch in range(num_epochs):
+         for i, (input, tags) in enumerate(train_loader):
+             input = input.to(device)
+             tags = tags.to(device)
+
+             optimizer.zero_grad()
+
+             outputs = model(input)
+             loss = criterion(outputs, tags)
+
+             loss.backward()
+             optimizer.step()
+
+             if i % 10 == 0:
+                 print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss [{loss.item():.4f}]")
+
+ def train_bidirlstm_embedding_model(model: BidirLSTMSegmenterWithEmbedding,
+                                     dataset: SegmentorDatasetNonEmbed,
+                                     lr=1e-3,
+                                     num_epochs=3,
+                                     batch_size=1,
+                                     ):
+     train_loader = get_dataloader(dataset, batch_size=batch_size)
+
+     n_total_steps = len(train_loader)
+     criterion = nn.BCELoss()
+     optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+     device = model.device
+
+     for epoch in range(num_epochs):
+         for i, (input, tags) in enumerate(train_loader):
+             input = input.to(device)
+             tags = tags.to(device)
+
+             optimizer.zero_grad()
+
+             outputs = model(input)
+             loss = criterion(outputs, tags)
+
+             loss.backward()
+             optimizer.step()
+
+             if i % 10 == 0:
+                 print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss [{loss.item():.4f}]")
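Both bidirectional models emit one break probability per input token, so the output shape mirrors the input sequence length. A minimal shape check as a sketch; all sizes here are illustrative (input_size must cover every tag ID produced by get_upenn_tags_dict()):

import torch
from model import BidirLSTMSegmenterWithEmbedding

model = BidirLSTMSegmenterWithEmbedding(input_size=46, embedding_size=128, hidden_size=128, num_layers=2, device="cpu")
x = torch.randint(0, 46, (2, 17))  # [batch, seq_length] of integer tag IDs
probs = model(x)                   # sigmoid outputs in (0, 1), one per token
assert probs.shape == (2, 17)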
model_consts.py ADDED
@@ -0,0 +1,9 @@
+ if __package__ is None or __package__ == "":
+     from utils import get_upenn_tags_dict
+ else:
+     from .utils import get_upenn_tags_dict
+
+ input_size = len(get_upenn_tags_dict())
+ embedding_size = 128
+ hidden_size = 128
+ num_layers = 2
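input_size is derived from the tagger's tag inventory at import time rather than hard-coded, so the embedding table always matches whatever get_upenn_tags_dict() produces. A quick sanity-check sketch (the exact tag count depends on the installed NLTK data):

from utils import get_upenn_tags_dict
from model_consts import input_size

tags_dict = get_upenn_tags_dict()
assert input_size == len(tags_dict)
assert tags_dict["BREAK"] == len(tags_dict) - 1  # BREAK is appended after sorting, so it is always last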
train.py ADDED
@@ -0,0 +1,31 @@
+ import torch
+ import os
+
+ if __package__ is None or __package__ == "":
+     from model import BidirLSTMSegmenter, SegmentorDatasetDirectTag, train_bidirlstm_model
+     from model import BidirLSTMSegmenterWithEmbedding, SegmentorDatasetNonEmbed, train_bidirlstm_embedding_model
+     from utils import get_upenn_tags_dict
+     from model_consts import input_size, embedding_size, hidden_size, num_layers
+     data_path = "data"
+ else:
+     from .model import BidirLSTMSegmenter, SegmentorDatasetDirectTag, train_bidirlstm_model
+     from .model import BidirLSTMSegmenterWithEmbedding, SegmentorDatasetNonEmbed, train_bidirlstm_embedding_model
+     from .utils import get_upenn_tags_dict
+     from .model_consts import input_size, embedding_size, hidden_size, num_layers
+     data_path = "segmenter/data"
+
+ device = "cuda"
+
+ if __name__ == "__main__":
+     dataset = SegmentorDatasetNonEmbed(data_path)
+     model = BidirLSTMSegmenterWithEmbedding(input_size, embedding_size, hidden_size, num_layers, device)
+
+     if os.path.exists("segmenter.ckpt") and os.path.isfile("segmenter.ckpt"):
+         print("Loading checkpoint. If you want to start from scratch, remove segmenter.ckpt.")
+         model.load_state_dict(torch.load("segmenter.ckpt"))
+
+     model.to(device)
+
+     train_bidirlstm_embedding_model(model, dataset, num_epochs=150, batch_size=2)
+
+     torch.save(model.state_dict(), "segmenter.ckpt")
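train.py only trains; a plausible inference loop would reuse the same constants and checkpoint. A sketch under stated assumptions: the stable-whisper model name and audio file are illustrative, and this assumes stable_whisper's transcribe() result exposes .segments whose entries carry word timings, as utils.get_indicies_autoembed() expects:

import torch
import stable_whisper

from model import BidirLSTMSegmenterWithEmbedding
from model_consts import input_size, embedding_size, hidden_size, num_layers
from utils import get_indicies_autoembed

device = "cpu"
model = BidirLSTMSegmenterWithEmbedding(input_size, embedding_size, hidden_size, num_layers, device)
model.load_state_dict(torch.load("segmenter.ckpt", map_location=device))
model.eval()

asr = stable_whisper.load_model("base")
result = asr.transcribe("episode01.mp3")

for segment in result.segments:
    # Word indices after which the segmenter wants a line break.
    print(get_indicies_autoembed(segment, model, device, threshold=0.5))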
utils.py ADDED
@@ -0,0 +1,254 @@
+ import nltk
+ from nltk.tag import PerceptronTagger
+ from stable_whisper.result import WordTiming
+ import numpy as np
+ import torch
+
+ def bind_wordtimings_to_tags(wt: list[WordTiming]):
+     raw_words = [w.word for w in wt]
+
+     tokenized_raw_words = []
+     tokens_wordtiming_map = []
+
+     for word in raw_words:
+         tokens_word = nltk.word_tokenize(word)
+         tokenized_raw_words.extend(tokens_word)
+         tokens_wordtiming_map.append(len(tokens_word))
+
+     tagged_words = nltk.pos_tag(tokenized_raw_words)
+
+     # Regroup the flat tag list so each WordTiming gets the tags of its own tokens.
+     grouped_tags = []
+     for k in tokens_wordtiming_map:
+         grouped_tags.append(tagged_words[:k])
+         tagged_words = tagged_words[k:]
+
+     tags_only = [tuple([w[1] for w in t]) for t in grouped_tags]
+
+     wordtimings_with_tags = zip(wt, tags_only)
+
+     return list(wordtimings_with_tags)
+
+ def embed_tag_list(tags: list[str]):
+     tags_dict = get_upenn_tags_dict()
+     eye = np.eye(len(tags_dict))
+     return eye[np.array([tags_dict[tag] for tag in tags])]
+
+ def lookup_tag_list(tags: list[str]):
+     tags_dict = get_upenn_tags_dict()
+     return np.array([tags_dict[tag] for tag in tags], dtype=int)
+
+ def tag_training_data(filename: str):
+     with open(filename, "r") as f:
+         segmented_lines = f.readlines()
+
+     segmented_lines = [s.strip() for s in segmented_lines if s.strip() != ""]
+
+     # Regain the full text for more accurate tagging.
+     full_text = " ".join(segmented_lines)
+
+     tokenized_full_text = nltk.word_tokenize(full_text)
+     tagged_full_text = nltk.pos_tag(tokenized_full_text)
+
+     tagged_full_text_copy = tagged_full_text
+
+     reconstructed_tags = []
+
+     for line in segmented_lines:
+         line_nospace = line.replace(" ", "")
+
+         found = False
+
+         # Find the token prefix whose concatenation matches this line (ignoring
+         # spaces), claim it for this line, and continue with the remaining tokens.
+         for i in range(len(tagged_full_text_copy) + 1):
+             rejoined = "".join([x[0] for x in tagged_full_text_copy[:i]])
+
+             if line_nospace == rejoined:
+                 found = True
+                 reconstructed_tags.append(tagged_full_text_copy[:i])
+                 tagged_full_text_copy = tagged_full_text_copy[i:]
+                 break
+
+         if not found:
+             print("Panic. Cannot match further.")
+             print(f"Was trying to match: {line}")
+             print(tagged_full_text_copy)
+
+     return reconstructed_tags
+
+ def get_upenn_tags_dict():
+     tagger = PerceptronTagger()
+
+     tags = list(tagger.tagdict.values())
+
+     # https://www.ling.upenn.edu/courses/Fall_2003/ling001/penn_treebank_pos.html
+     tags.extend(["CC", "CD", "DT", "EX", "FW", "IN", "JJ", "JJR", "JJS", "LS", "MD", "NN", "NNS", "NNP", "NNPS", "PDT", "POS", "PRP", "PRP$", "RB", "RBR", "RBS", "RP", "SYM", "TO", "UH", "VB", "VBD", "VBG", "VBN", "VBP", "VBZ", "WDT", "WP", "WP$", "WRB"])
+     tags = list(set(tags))
+     tags.sort()
+     tags.append("BREAK")
+
+     tags_dict = dict()
+
+     for index, tag in enumerate(tags):
+         tags_dict[tag] = index
+
+     return tags_dict
+
+
+ def parse_tags(reconstructed_tags):
+     """
+     Parse reconstructed tags into an input/tag datapoint.
+     In the original plan, this type of output is suitable for a bidirectional LSTM.
+
+     Input:
+         reconstructed_tags:
+             Tagged segments, from tag_training_data().
+             Example: [
+                 [('You', 'PRP'), ("'re", 'VBP'), ('back', 'RB'), ('again', 'RB'), ('?', '.')],
+                 [('You', 'PRP'), ("'ve", 'VBP'), ('been', 'VBN'), ('consuming', 'VBG'), ('a', 'DT'), ('lot', 'NN'), ('of', 'IN'), ('tech', 'JJ'), ('news', 'NN'), ('lately', 'RB'), ('.', '.')]
+                 ...
+             ]
+
+     Output:
+         (input_tokens, output_tag)
+         input_tokens:
+             A sequence of tokens; each number corresponds to a type of word.
+             Example: [25, 38, 27, 27, 6, 25, 38, 37, 36, 10, 19, 13, 14, 19, 27, 6]
+         output_tag:
+             A sequence of 0s and 1s, indicating whether a break should be inserted AFTER each location.
+             Example: [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
+     """
+     tags_dict = get_upenn_tags_dict()
+
+     all_tags_sequence = [[y[1] for y in segments] + ['BREAK'] for segments in reconstructed_tags]
+     all_tags_sequence = [tag for tags in all_tags_sequence for tag in tags]
+
+     input_tokens = []
+     output_tag = []
+     for token in all_tags_sequence:
+         if token != 'BREAK':
+             input_tokens.append(tags_dict[token])
+             output_tag.append(0)
+         else:
+             # A BREAK marker flips the label of the preceding token to 1.
+             output_tag[-1] = 1
+
+     return input_tokens, output_tag
+
+ def embed_segments(tagged_segments):
+     # get_upenn_tags_dict() already maps tag -> index.
+     tags_dict = get_upenn_tags_dict()
+
+     result_embedding = []
+
+     classes = len(tags_dict)
+     eye = np.eye(classes)
+
+     for segment in tagged_segments:
+         targets = np.array([tags_dict[tag] for word, tag in segment])
+         segment_embedding = eye[targets]
+
+         result_embedding.append(segment_embedding)
+         result_embedding.append(np.array([eye[tags_dict["BREAK"]]]))
+
+     result_embedding = np.concatenate(result_embedding)
+
+     return result_embedding, tags_dict
+
+ def window_embedded_segments_rnn(embeddings, tags_dict):
+     datapoints = []
+     eye = np.eye(len(tags_dict))
+
+     break_vector = eye[tags_dict["BREAK"]]
+
+     for i in range(1, embeddings.shape[0]):
+         # Should we insert a break BEFORE token i?
+         if (embeddings[i] == break_vector).all():
+             continue
+         else:
+             prev_sequence = embeddings[:i]
+
+             if (prev_sequence[-1] == break_vector).all():
+                 # It should break here. Remove the break and set the tag to 1.
+                 prev_sequence = prev_sequence[:-1]
+                 tag = 1
+             else:
+                 # It should not break here.
+                 tag = 0
+
+             entire_sequence = np.concatenate((prev_sequence, np.array([embeddings[i]])))
+
+             datapoints.append((entire_sequence, tag))
+     return datapoints
+
+ def print_dataset(datapoints, tags_dict, tokenized_full_text):
+     eye = np.eye(len(tags_dict))
+
+     break_vector = eye[tags_dict["BREAK"]]
+
+     for input, tag in datapoints:
+         print("[1] " if tag == 1 else "[0] ", end='')
+
+         # Count the non-BREAK vectors to find how many real tokens the window covers.
+         count = 0
+         for v in input:
+             if not (v == break_vector).all():
+                 count += 1
+         segment = tokenized_full_text[:count]
+         print(segment)
+
+ from stable_whisper.result import Segment  # Just for typing
+
+ def get_indicies(segment: Segment, model, device, threshold):
+     word_list = segment.words
+     tagged_wordtiming = bind_wordtimings_to_tags(word_list)
+
+     tag_list = [tag for twt in tagged_wordtiming for tag in twt[1]]
+
+     # How many tags belong to each word, so per-tag outputs can be mapped back to words.
+     tag_per_word = [len(twt[1]) for twt in tagged_wordtiming]
+
+     embedded_tags = embed_tag_list(tag_list)
+     embedded_tags = torch.from_numpy(embedded_tags).float()
+
+     output = model(embedded_tags[None, :].to(device))
+
+     list_output = output.detach().cpu().numpy().tolist()[0]
+
+     current_index = 0
+     cut_indicies = []
+     for index, tags_count in enumerate(tag_per_word):
+         tags = list_output[current_index:current_index + tags_count]
+         if max(tags) > threshold:
+             cut_indicies.append(index)
+         current_index += tags_count
+
+     return cut_indicies
+
+ def get_indicies_autoembed(segment: Segment, model, device, threshold):
+     word_list = segment.words
+     tagged_wordtiming = bind_wordtimings_to_tags(word_list)
+
+     tag_list = [tag for twt in tagged_wordtiming for tag in twt[1]]
+
+     tag_per_word = [len(twt[1]) for twt in tagged_wordtiming]
+
+     embedded_tags = lookup_tag_list(tag_list)
+     embedded_tags = torch.from_numpy(embedded_tags).int()
+
+     output = model(embedded_tags[None, :].to(device))
+
+     list_output = output.detach().cpu().numpy().tolist()[0]
+
+     current_index = 0
+     cut_indicies = []
+     for index, tags_count in enumerate(tag_per_word):
+         tags = list_output[current_index:current_index + tags_count]
+         if max(tags) > threshold:
+             cut_indicies.append(index)
+         current_index += tags_count
+
+     return cut_indicies
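The training-side pipeline in this file is: a pre-segmented transcript goes through tag_training_data(), which re-tags the joined text and re-splits the tags along the original line breaks, and parse_tags() then flattens that into parallel tag-ID and break-label sequences. A minimal round-trip sketch, assuming the NLTK tokenizer and tagger data are installed (the sample lines are taken from the parse_tags() docstring):

import tempfile
from utils import tag_training_data, parse_tags

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("You're back again?\nYou've been consuming a lot of tech news lately.\n")
    fname = f.name

reconstructed = tag_training_data(fname)    # one list of (token, tag) pairs per transcript line
tokens, breaks = parse_tags(reconstructed)  # breaks[i] == 1 iff a line break follows token i
assert len(tokens) == len(breaks)
assert breaks[-1] == 1                      # every transcript line ends in a break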