"""Part-of-speech helpers for predicting line breaks in transcribed segments.

The module tags words with NLTK's Penn Treebank tagger, turns the tags into
integer ids or one-hot rows, builds training data from pre-segmented text and
queries a model for the word positions that should be cut on.
"""

import functools

import nltk
import numpy as np
import torch
from nltk.tag import PerceptronTagger
from stable_whisper.result import Segment, WordTiming

# https://www.ling.upenn.edu/courses/Fall_2003/ling001/penn_treebank_pos.html
_PENN_TREEBANK_TAGS = (
    "CC", "CD", "DT", "EX", "FW", "IN", "JJ", "JJR", "JJS", "LS", "MD", "NN",
    "NNS", "NNP", "NNPS", "PDT", "POS", "PRP", "PRP$", "RB", "RBR", "RBS",
    "RP", "SYM", "TO", "UH", "VB", "VBD", "VBG", "VBN", "VBP", "VBZ", "WDT",
    "WP", "WP$", "WRB",
)


def bind_wordtimings_to_tags(wt: list[WordTiming]):
    """Pair every word timing with the POS tags of the tokens it contains.

    Each ``word`` is tokenized on its own, but all tokens are tagged in a
    single ``nltk.pos_tag`` call so the tagger sees the whole sequence.

    Returns:
        A list of ``(word_timing, tags)`` tuples in input order, where
        ``tags`` is a tuple with one tag per token of that word (it is empty
        when the word produced no tokens).
    """
    tokens_per_word = [nltk.word_tokenize(timing.word) for timing in wt]
    flat_tokens = [token for tokens in tokens_per_word for token in tokens]
    tagged_tokens = nltk.pos_tag(flat_tokens)

    bound = []
    offset = 0
    for timing, tokens in zip(wt, tokens_per_word):
        end = offset + len(tokens)
        bound.append((timing, tuple(tag for _, tag in tagged_tokens[offset:end])))
        offset = end
    return bound


def embed_tag_list(tags: list[str]):
    """Return a one-hot matrix with one row per tag in ``tags``.

    The row width is the size of the dictionary returned by
    ``get_upenn_tags_dict()``.

    Raises:
        KeyError: if a tag is missing from that dictionary.
    """
    tags_dict = get_upenn_tags_dict()
    # An explicit integer dtype keeps an empty ``tags`` list usable as an
    # index array (``np.array([])`` would default to float).
    indices = np.array([tags_dict[tag] for tag in tags], dtype=int)
    return np.eye(len(tags_dict))[indices]


def lookup_tag_list(tags: list[str]):
    """Return the integer id of every tag in ``tags`` as an int array.

    Raises:
        KeyError: if a tag is missing from ``get_upenn_tags_dict()``.
    """
    tags_dict = get_upenn_tags_dict()
    return np.array([tags_dict[tag] for tag in tags], dtype=int)


def _find_segment_end(target: str, tagged_tokens, start: int):
    """Find where the tokens from ``start`` spell out ``target`` exactly.

    Returns the exclusive end index of the shortest matching run, or ``None``
    when the tokens diverge from ``target`` or run out first.
    """
    if not target:
        return start
    position = 0
    for end in range(start, len(tagged_tokens)):
        token = tagged_tokens[end][0]
        if not target.startswith(token, position):
            return None
        position += len(token)
        if position == len(target):
            return end + 1
    return None


def tag_training_data(filename: str):
    """Tag a pre-segmented text file and regroup the tags line by line.

    Non-blank lines are joined into one text, tokenized and tagged as a
    whole, then the tagged tokens are dealt back out to the lines by
    comparing the space-stripped line with the concatenated tokens.

    Returns:
        A list with one entry per matched line; each entry is a list of
        ``(token, tag)`` tuples. A line that cannot be matched is reported on
        stdout and contributes no entry.
    """
    with open(filename, "r") as f:
        segmented_lines = [line.strip() for line in f if line.strip() != ""]

    # Regain the full text for more accurate tagging.
    full_text = " ".join(segmented_lines)
    tagged_full_text = nltk.pos_tag(nltk.word_tokenize(full_text))

    # NOTE(review): matching relies on the tokenizer leaving characters
    # untouched; a tokenizer that rewrites characters (quotes, for instance)
    # would make such lines unmatchable - confirm against the training data.
    reconstructed_tags = []
    cursor = 0  # index of the first token not yet assigned to a line
    for line in segmented_lines:
        line_nospace = line.replace(" ", "")
        end = _find_segment_end(line_nospace, tagged_full_text, cursor)
        if end is None:
            print("Panic. Cannot match further.")
            print(f"Was trying to match: {line}")
            print(tagged_full_text[cursor:])
            continue
        reconstructed_tags.append(tagged_full_text[cursor:end])
        cursor = end
    return reconstructed_tags


@functools.lru_cache(maxsize=1)
def _sorted_tag_names():
    """Return the ordered tag inventory; cached so the tagger loads once."""
    tagger = PerceptronTagger()
    names = set(tagger.tagdict.values())
    names.update(_PENN_TREEBANK_TAGS)
    # "BREAK" is appended after sorting so it always takes the highest id.
    return (*sorted(names), "BREAK")


def get_upenn_tags_dict():
    """Return a fresh ``{tag: id}`` dictionary.

    Ids follow the sorted order of the tagger's ``tagdict`` values merged
    with the Penn Treebank tag list; the extra ``"BREAK"`` tag comes last.
    A new dict is built on every call, so callers may mutate the result.
    """
    return {tag: index for index, tag in enumerate(_sorted_tag_names())}


def parse_tags(reconstructed_tags):
    """
    Parse reconstructed tags into input/tag datapoint.
    In the original plan, this type of output is suitable for bidirectional LSTM.

    Input:
        reconstructed_tags: Tagged segments, from tag_training_data()
        Example: [
            [('You', 'PRP'), ("'re", 'VBP'), ('back', 'RB'), ('again', 'RB'), ('?', '.')],
            [('You', 'PRP'), ("'ve", 'VBP'), ('been', 'VBN'), ('consuming', 'VBG'),
             ('a', 'DT'), ('lot', 'NN'), ('of', 'IN'), ('tech', 'JJ'), ('news', 'NN'),
             ('lately', 'RB'), ('.', '.')]
            ...
        ]

    Output: (input_tokens, output_tag)
        input_tokens: A sequence of tokens, each number corresponds to a type of word.
        Example: [25, 38, 27, 27, 6, 25, 38, 37, 36, 10, 19, 13, 14, 19, 27, 6]
        output_tag: A sequence of 0 and 1, indicating whether a break
        should be inserted AFTER each location.
        Example: [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
    """
    tags_dict = get_upenn_tags_dict()
    input_tokens = []
    output_tag = []
    for segment in reconstructed_tags:
        for _, tag in segment:
            input_tokens.append(tags_dict[tag])
            output_tag.append(0)
        # The segment boundary marks the last token seen so far; the guard
        # covers an empty segment appearing before any token.
        if output_tag:
            output_tag[-1] = 1
    return input_tokens, output_tag


def embed_segments(tagged_segments):
    """One-hot encode tagged segments, adding a BREAK row after each one.

    Args:
        tagged_segments: iterable of segments, each a sequence of
            ``(word, tag)`` tuples.

    Returns:
        ``(embedding, tags_dict)`` where ``embedding`` has one row per token
        plus one BREAK row per segment, and ``tags_dict`` is the tag-to-id
        mapping used. With no segments the embedding has zero rows.
    """
    tags_dict = get_upenn_tags_dict()
    classes = len(tags_dict)
    eye = np.eye(classes)
    break_row = eye[[tags_dict["BREAK"]]]  # list index keeps it 2-D

    pieces = []
    for segment in tagged_segments:
        targets = np.array([tags_dict[tag] for _, tag in segment], dtype=int)
        pieces.append(eye[targets])
        pieces.append(break_row)

    if not pieces:
        return np.empty((0, classes)), tags_dict
    return np.concatenate(pieces), tags_dict


def window_embedded_segments_rnn(embeddings, tags_dict):
    """Build growing-prefix datapoints that ask "break before token i?".

    For every non-BREAK row ``i >= 1`` one ``(sequence, tag)`` pair is
    produced. ``sequence`` is every row before ``i`` followed by row ``i``;
    when the row right before ``i`` is a BREAK, that row is dropped and
    ``tag`` is 1, otherwise ``tag`` is 0. Earlier BREAK rows stay in the
    sequence. Row 0 never gets a datapoint of its own.
    """
    break_vector = np.eye(len(tags_dict))[tags_dict["BREAK"]]
    is_break = (embeddings == break_vector).all(axis=1)

    datapoints = []
    for i in range(1, embeddings.shape[0]):
        # Should we insert a break BEFORE token i?
        if is_break[i]:
            continue
        if is_break[i - 1]:
            # It should break here. Remove the break and set tag as 1.
            history = embeddings[:i - 1]
            tag = 1
        else:
            # It should not break here.
            history = embeddings[:i]
            tag = 0
        datapoints.append((np.concatenate((history, embeddings[i:i + 1])), tag))
    return datapoints


def print_dataset(datapoints, tags_dict, tokenized_full_text):
    """Print each datapoint as its tag followed by the tokens it covers.

    The number of tokens shown equals the number of non-BREAK rows in the
    datapoint's sequence, taken from the start of ``tokenized_full_text``.
    """
    break_vector = np.eye(len(tags_dict))[tags_dict["BREAK"]]
    for sequence, tag in datapoints:
        print("[1] " if tag == 1 else "[0] ", end='')
        count = sum(1 for row in sequence if not (row == break_vector).all())
        print(tokenized_full_text[:count])


def _flatten_segment_tags(segment):
    """Return ``(tag_list, tags_per_word)`` for the words of ``segment``."""
    tagged_wordtiming = bind_wordtimings_to_tags(segment.words)
    tag_list = [tag for _, word_tags in tagged_wordtiming for tag in word_tags]
    tags_per_word = [len(word_tags) for _, word_tags in tagged_wordtiming]
    return tag_list, tags_per_word


def _select_cut_indices(scores, tags_per_word, threshold):
    """Return indices of words whose highest token score beats ``threshold``."""
    cut_indices = []
    offset = 0
    for word_index, tags_count in enumerate(tags_per_word):
        word_scores = scores[offset:offset + tags_count]
        offset += tags_count
        # A word that produced no tokens has no score and is never a cut.
        if word_scores and max(word_scores) > threshold:
            cut_indices.append(word_index)
    return cut_indices


def get_indicies(segment: Segment, model, device, threshold):
    """Return the word indices of ``segment`` the model scores above ``threshold``.

    Tags are fed to ``model`` as a batch of one sequence of float one-hot
    rows. The first batch entry of the output is read as one score per tag
    (assumes the model returns a (batch, sequence) shaped tensor - TODO
    confirm). Returns an empty list when the segment yields no tags.
    """
    tag_list, tags_per_word = _flatten_segment_tags(segment)
    if not tag_list:
        return []
    embedded_tags = torch.from_numpy(embed_tag_list(tag_list)).float()
    output = model(embedded_tags[None, :].to(device))
    scores = output.detach().cpu().numpy().tolist()[0]
    return _select_cut_indices(scores, tags_per_word, threshold)


def get_indicies_autoembed(segment: Segment, model, device, threshold):
    """Same as ``get_indicies`` but feeds integer tag ids instead of one-hot rows.

    The model is given an int tensor holding a batch of one sequence of tag
    ids (presumably it embeds them itself - verify against the model).
    Returns an empty list when the segment yields no tags.
    """
    tag_list, tags_per_word = _flatten_segment_tags(segment)
    if not tag_list:
        return []
    tag_ids = torch.from_numpy(lookup_tag_list(tag_list)).int().to(device)
    output = model(tag_ids[None, :])
    scores = output.detach().cpu().numpy().tolist()[0]
    return _select_cut_indices(scores, tags_per_word, threshold)