Spaces:
Runtime error
Runtime error
from utils import re_findall | |
from shared import OutputArguments | |
from typing import Optional | |
from segment import ( | |
generate_segments, | |
extract_segment, | |
SAFETY_TOKENS, | |
CustomTokens, | |
word_start, | |
word_end, | |
SegmentationArguments | |
) | |
import preprocess | |
from errors import TranscriptError | |
from model import get_classifier_vectorizer | |
from transformers import ( | |
AutoModelForSeq2SeqLM, | |
AutoTokenizer, | |
HfArgumentParser | |
) | |
from transformers.trainer_utils import get_last_checkpoint | |
from dataclasses import dataclass, field | |
from shared import device | |
import logging | |
def seconds_to_time(seconds):
    """Format a (possibly negative) number of seconds as ``[-]HH:MM:SS[.fff]``.

    The fractional part is rounded to millisecond precision and printed
    without trailing zeros (``1.5`` -> ``"00:00:01.5"``); it is omitted when it
    rounds to zero. Hours are zero-padded to two digits but never truncated.

    Args:
        seconds: Time offset in seconds (int or float).

    Returns:
        The formatted time string.
    """
    # Work in whole milliseconds on the absolute value so that:
    #  * negative inputs get the right fraction (Python's ``%`` on a negative
    #    float returns the complement, e.g. -1.25 % 1 == 0.75), and
    #  * a fraction that rounds up to a full second carries into the seconds
    #    field instead of being printed as ".0".
    total_ms = int(round(abs(seconds) * 1000))
    whole_seconds, millis = divmod(total_ms, 1000)
    fractional = f'.{millis:03}'.rstrip('0') if millis else ''

    h, remainder = divmod(whole_seconds, 3600)
    m, s = divmod(remainder, 60)

    # Only show a sign when a non-zero value is actually printed.
    sign = '-' if (seconds < 0 and total_ms) else ''
    return f"{sign}{h:02}:{m:02}:{s:02}{fractional}"
@dataclass
class TrainingOutputArguments:
    """Arguments locating the trained model used for prediction.

    Declared as a dataclass because ``HfArgumentParser`` only accepts dataclass
    types, and because ``__post_init__`` (which resolves ``model_path``) is only
    invoked automatically for dataclasses.
    """

    # Explicit model path; when left as None it is resolved in __post_init__.
    model_path: str = field(
        default=None,
        metadata={
            'help': 'Path to pretrained model used for prediction'}
    )

    # Reuse the ``output_dir`` field definition (default and help text) from
    # ``OutputArguments`` so both argument groups stay in sync.
    output_dir: Optional[str] = OutputArguments.__dataclass_fields__[
        'output_dir']

    def __post_init__(self):
        """Fall back to the checkpoint found in ``output_dir`` if no path given.

        Raises:
            Exception: if ``model_path`` is unset and ``get_last_checkpoint``
                returns None for ``output_dir``.
        """
        if self.model_path is not None:
            return

        last_checkpoint = get_last_checkpoint(self.output_dir)
        if last_checkpoint is not None:
            self.model_path = last_checkpoint
        else:
            raise Exception(
                'Unable to find model, explicitly set `--model_path`')
@dataclass
class PredictArguments(TrainingOutputArguments):
    """Command-line arguments for predicting segments of a single video.

    Must be a dataclass for ``HfArgumentParser`` to accept it and for the
    ``field(...)`` declaration below to become a real default.
    """

    # Video ID to run prediction on; None means "not supplied".
    video_id: str = field(
        default=None,
        metadata={
            'help': 'Video to predict sponsorship segments for'}
    )
# Pattern for parsing the model's decoded output. Each match begins right
# after a START_SEGMENT token (lookbehind), expects an underscore-prefixed
# category, then lazily captures the segment text up to the next END_SEGMENT
# token or the end of the string (lookahead).
# NOTE(review): the token values are interpolated unescaped; if they ever
# contain regex metacharacters they would need re.escape().
SPONSOR_MATCH_RE = (
    fr'(?<={CustomTokens.START_SEGMENT.value})'
    r'\s*_(?P<category>\S+)'
    r'\s*(?P<text>.*?)'
    fr'\s*(?={CustomTokens.END_SEGMENT.value}|$)'
)

# Number of words taken from each end of a predicted span when aligning it
# against the transcript with greedy_match. Larger values improve accuracy but
# make matching slower.
MATCH_WINDOW = 25

# Same-category predictions whose gap is at most this many seconds are merged.
MERGE_TIME_WITHIN = 8
@dataclass
class ClassifierArguments:
    """Arguments for the classifier used to re-score model predictions.

    Declared as a dataclass so ``HfArgumentParser`` accepts it and the
    ``field(...)`` declarations below become real default values.
    """

    # Directory that holds ``classifier_file`` and ``vectorizer_file``.
    classifier_dir: Optional[str] = field(
        default='classifiers',
        metadata={
            'help': 'The directory that contains the classifier and vectorizer.'
        }
    )
    # File name of the classifier.
    classifier_file: Optional[str] = field(
        default='classifier.pickle',
        metadata={
            'help': 'The name of the classifier'
        }
    )
    # File name of the vectorizer.
    vectorizer_file: Optional[str] = field(
        default='vectorizer.pickle',
        metadata={
            'help': 'The name of the vectorizer'
        }
    )
    # Predictions with a classifier probability below this value are removed.
    min_probability: float = field(
        default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
def filter_predictions(predictions, classifier_args):
    """Score predictions with the text classifier and keep the confident ones.

    Each prediction's words (``pred['words']``, items with a ``'text'`` key) are
    joined, cleaned with ``preprocess.clean_text`` and vectorized. Every scored
    prediction is annotated in place with ``'probability'``, which is column 1
    of ``classifier.predict_proba`` (presumably the positive class — TODO
    confirm against the training code).

    Args:
        predictions: List of prediction dicts; returned as-is when falsy.
        classifier_args: Object providing the classifier location and
            ``min_probability``.

    Returns:
        The predictions whose probability is at least
        ``classifier_args.min_probability``, in their original order.
    """
    if not predictions:
        return predictions

    classifier, vectorizer = get_classifier_vectorizer(classifier_args)

    documents = []
    for pred in predictions:
        joined_words = ' '.join(word['text'] for word in pred['words'])
        documents.append(preprocess.clean_text(joined_words))

    scores = classifier.predict_proba(vectorizer.transform(documents))

    kept = []
    for prediction, class_scores in zip(predictions, scores):
        prediction['probability'] = class_scores[1]
        if prediction['probability'] >= classifier_args.min_probability:
            kept.append(prediction)
    return kept
def predict(video_id, model, tokenizer, segmentation_args, words=None, classifier_args=None):
    """Predict segments (time ranges, categories and words) for one video.

    Args:
        video_id: Video whose transcript is fetched via ``preprocess.get_words``
            when ``words`` is not supplied.
        model, tokenizer: Seq2seq model and tokenizer used for prediction.
        segmentation_args: Passed to ``generate_segments``.
        words: Optional pre-fetched transcript words, to avoid a second fetch.
        classifier_args: When given, predictions are additionally filtered
            with ``filter_predictions``.

    Returns:
        List of prediction dicts, each with a ``'words'`` entry holding the
        transcript words inside its ``start``/``end`` range.

    Raises:
        TranscriptError: if no transcript words are available.
    """
    transcript = preprocess.get_words(video_id) if words is None else words
    if not transcript:
        raise TranscriptError('Unable to retrieve transcript')

    segments = generate_segments(transcript, tokenizer, segmentation_args)
    predictions = segments_to_prediction_times(segments, model, tokenizer)

    # Attach the transcript words that fall inside each predicted time range.
    for prediction in predictions:
        prediction['words'] = extract_segment(
            transcript, prediction['start'], prediction['end'])

    if classifier_args is None:
        return predictions
    return filter_predictions(predictions, classifier_args)
def greedy_match(list, sublist):
    """Find the longest contiguous run shared by ``list`` and ``sublist``.

    Uses a common-run dynamic programme in O(len(list) * len(sublist)) time
    instead of re-comparing every window (the previous approach grew roughly
    cubically with the window size).

    Ties are broken exactly like the original scan: the first ``(i, j)`` in
    row-major order (``i`` over ``list``, then ``j`` over ``sublist``) whose run
    is strictly longer than all earlier ones wins.

    Args:
        list: Sequence to search in (parameter name kept for compatibility).
        sublist: Sequence whose pieces are searched for.

    Returns:
        ``(i, j, k)``: the run starts at ``list[i]`` and ``sublist[j]`` and has
        length ``k``; ``(-1, -1, 0)`` when no element matches.
    """
    haystack, needle = list, sublist
    n, m = len(haystack), len(needle)

    # run[i][j] = length of the common run starting at haystack[i], needle[j].
    # One extra row/column of zeros acts as the "past the end" sentinel.
    run = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n - 1, -1, -1):
        row, next_row = run[i], run[i + 1]
        for j in range(m - 1, -1, -1):
            if haystack[i] == needle[j]:
                row[j] = next_row[j + 1] + 1

    best_i, best_j, best_k = -1, -1, 0
    for i in range(n):
        row = run[i]
        for j in range(m):
            if row[j] > best_k:
                best_i, best_j, best_k = i, j, row[j]
    return best_i, best_j, best_k
def predict_sponsor_text(text, model, tokenizer):
    """Run the seq2seq model on ``text`` and return its decoded output.

    The input is prefixed with the EXTRACT_SEGMENTS_PREFIX token, truncated by
    the tokenizer, and decoded with special tokens skipped.
    """
    prompt = f'{CustomTokens.EXTRACT_SEGMENTS_PREFIX.value} {text}'
    encoded = tokenizer(prompt, return_tensors='pt', truncation=True)
    input_ids = encoded.input_ids.to(device())

    # Cap generation at the input length plus SAFETY_TOKENS, and never beyond
    # model.model_dim.
    # NOTE(review): on T5-style models ``model_dim`` is the hidden size
    # (d_model), not a sequence-length limit — confirm this cap is intended.
    max_out_len = min(len(input_ids[0]) + SAFETY_TOKENS, model.model_dim)

    generated = model.generate(input_ids, max_length=max_out_len)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
def predict_sponsor_matches(text, model, tokenizer):
    """Parse the model's output for ``text`` into segment matches.

    Returns an empty list when the output contains the NO_SEGMENT token;
    otherwise whatever ``re_findall(SPONSOR_MATCH_RE, output)`` returns
    (callers read ``'category'`` and ``'text'`` from each match).
    """
    decoded = predict_sponsor_text(text, model, tokenizer)
    no_segment = CustomTokens.NO_SEGMENT.value in decoded
    return [] if no_segment else re_findall(SPONSOR_MATCH_RE, decoded)
def segments_to_prediction_times(segments, model, tokenizer):
    """Predict time ranges for every transcript segment and merge them.

    Args:
        segments: Iterable of segments; each is a list of word dicts with a
            ``'text'`` key (plus whatever ``word_start``/``word_end`` read).
        model, tokenizer: Passed through to ``predict_sponsor_matches``.

    Returns:
        List of ``{'start', 'end', 'category'}`` dicts sorted by start time, in
        which same-category ranges that overlap or are at most
        ``MERGE_TIME_WITHIN`` seconds apart have been merged.
    """
    predicted_time_ranges = []

    # TODO pass to model simultaneously, not in for loop
    # use 2d array for input ids
    for segment in segments:
        cleaned_batch = [preprocess.clean_text(word['text']) for word in segment]
        batch_text = ' '.join(cleaned_batch)

        for match in predict_sponsor_matches(batch_text, model, tokenizer):
            time_range = _match_to_time_range(segment, cleaned_batch, match)
            if time_range is not None:
                predicted_time_ranges.append(time_range)

    # Necessary to sort matches by start time before merging.
    predicted_time_ranges.sort(key=word_start)

    # Segments are fed to the model separately (it has a maximum input size),
    # so a single sponsorship may be split across neighbouring predictions.
    return _merge_time_ranges(predicted_time_ranges)


def _match_to_time_range(segment, cleaned_batch, match):
    """Map one model match back onto ``segment``; None if it cannot be aligned."""
    matched_text = match['text'].split()
    # TODO skip if too short

    # Align the first and last MATCH_WINDOW predicted words with the cleaned
    # transcript to find where the predicted span starts and ends.
    i1, _, _ = greedy_match(cleaned_batch, matched_text[:MATCH_WINDOW])
    i2, _, k2 = greedy_match(cleaned_batch, matched_text[-MATCH_WINDOW:])

    # greedy_match reports "no match" as index -1; using that as a slice bound
    # would silently wrap around to the end of the segment, so drop the match.
    if i1 < 0 or i2 < 0:
        return None

    extracted_words = segment[i1:i2 + k2]
    if not extracted_words:
        return None

    return {
        'start': word_start(extracted_words[0]),
        'end': word_end(extracted_words[-1]),
        'category': match['category'],
    }


def _merge_time_ranges(time_ranges):
    """Merge start-sorted same-category ranges that overlap or are close."""
    merged = []
    for time_range in time_ranges:
        last = merged[-1] if merged else None
        # Compare against the merged range (not the raw previous one): a gap
        # <= 0 means overlap, a gap <= MERGE_TIME_WITHIN means "close enough".
        if (last is not None
                and time_range['category'] == last['category']
                and time_range['start'] - last['end'] <= MERGE_TIME_WITHIN):
            # Never shrink the merged range when this one ends earlier.
            last['end'] = max(last['end'], time_range['end'])
        else:
            merged.append({
                'start': time_range['start'],
                'end': time_range['end'],
                'category': time_range['category'],
            })
    return merged
def main():
    """Command-line entry point: predict and print segments for one video."""
    # Test on unseen data
    logging.getLogger().setLevel(logging.DEBUG)

    parser = HfArgumentParser((
        PredictArguments,
        SegmentationArguments,
        ClassifierArguments
    ))
    predict_args, segmentation_args, classifier_args = parser.parse_args_into_dataclasses()

    if predict_args.video_id is None:
        print('No video ID supplied. Use `--video_id`.')
        return

    model = AutoModelForSeq2SeqLM.from_pretrained(predict_args.model_path)
    model.to(device())
    tokenizer = AutoTokenizer.from_pretrained(predict_args.model_path)

    predict_args.video_id = predict_args.video_id.strip()
    video_id = predict_args.video_id

    # TODO add back , classifier_args=classifier_args
    predictions = predict(video_id, model, tokenizer, segmentation_args)

    video_url = f'https://www.youtube.com/watch?v={video_id}'
    if not predictions:
        print('No predictions found for', video_url)
        return

    print(len(predictions), 'predictions found for', video_url)
    for number, prediction in enumerate(predictions, start=1):
        print(f'Prediction #{number}:')
        joined_text = ' '.join(w['text'] for w in prediction['words'])
        print('Text: "', joined_text, '"', sep='')
        start_label = seconds_to_time(prediction['start'])
        end_label = seconds_to_time(prediction['end'])
        print('Time:', start_label, '-->', end_label)
        print('Probability:', prediction.get('probability'))
        print('Category:', prediction.get('category'))
        print()


if __name__ == '__main__':
    main()