File size: 14,162 Bytes
537f2b7
 
 
 
a294fb2
b3b69aa
5fbdd3c
 
 
 
 
 
a294fb2
63f1925
a6de017
a294fb2
5fbdd3c
 
 
 
 
 
 
a294fb2
183ba5e
5fbdd3c
 
 
9b9ffd0
5fbdd3c
 
 
 
 
 
 
 
 
a294fb2
 
5fbdd3c
 
 
 
 
 
 
537f2b7
 
 
 
 
 
 
5fbdd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a294fb2
5fbdd3c
 
 
 
 
a294fb2
5fbdd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537f2b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fbdd3c
 
 
a294fb2
5fbdd3c
63f1925
 
5fbdd3c
 
a294fb2
5fbdd3c
a294fb2
5fbdd3c
a294fb2
 
 
 
5fbdd3c
 
 
a294fb2
5fbdd3c
 
 
a294fb2
537f2b7
 
 
 
 
 
 
 
 
 
 
 
a294fb2
537f2b7
 
a294fb2
537f2b7
 
a294fb2
 
5fbdd3c
 
 
 
 
 
 
 
 
 
a294fb2
 
5fbdd3c
537f2b7
 
a294fb2
537f2b7
 
5fbdd3c
 
 
 
 
 
 
3f7ce4e
5fbdd3c
537f2b7
 
 
 
 
 
 
5fbdd3c
537f2b7
 
 
 
5fbdd3c
537f2b7
 
 
 
 
 
5fbdd3c
537f2b7
 
5fbdd3c
537f2b7
 
 
 
 
a294fb2
537f2b7
 
 
 
a294fb2
 
537f2b7
a294fb2
 
 
183ba5e
a294fb2
 
 
 
 
 
 
 
 
 
183ba5e
 
 
 
 
 
 
 
 
 
a294fb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fbdd3c
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
import itertools
import base64
import re
import requests
from model import get_model_tokenizer
from utils import jaccard
from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    HfArgumentParser
)
from preprocess import DatasetArguments, get_words
from shared import device, GeneralArguments
from predict import ClassifierArguments, predict, TrainingOutputArguments
from segment import extract_segment, word_start, word_end, SegmentationArguments, add_labels_to_words
import pandas as pd
from dataclasses import dataclass, field
from typing import Optional
from tqdm import tqdm
import json
import os
import random
from shared import seconds_to_time
from urllib.parse import quote


@dataclass
class EvaluationArguments(TrainingOutputArguments):
    """
    Arguments controlling which videos are evaluated and where the resulting
    per-video metrics are written (in addition to the inherited training
    output arguments, e.g. the model path).
    """
    max_videos: Optional[int] = field(
        default=None,
        metadata={
            'help': 'The number of videos to test on'
        }
    )
    # Optional[int] rather than int: the default is None ("start at the first video").
    start_index: Optional[int] = field(default=None, metadata={
        'help': 'Video to start the evaluation at.'})
    output_file: Optional[str] = field(
        default='metrics.csv',
        metadata={
            'help': 'Save metrics to output file'
        }
    )

    # When set, videos are taken from this channel instead of the labelled dataset.
    channel_id: Optional[str] = field(
        default=None,
        metadata={
            'help': 'Used to evaluate a channel'
        }
    )


def attach_predictions_to_sponsor_segments(predictions, sponsor_segments):
    """Cross-link predicted segments with labelled sponsor segments.

    Every prediction receives ``best_overlap`` and ``best_sponsorship``;
    every labelled segment receives ``best_overlap`` and ``best_prediction``.
    Each records the partner with the highest ``jaccard`` overlap. A link is
    only made when that overlap exceeds zero; otherwise the partner stays None.

    Both input lists are mutated in place. The labelled segments list is
    returned for convenience.
    """
    # Reset prediction-side bookkeeping before any comparisons are made.
    for pred in predictions:
        pred['best_overlap'], pred['best_sponsorship'] = 0, None

    for labelled in sponsor_segments:
        labelled['best_overlap'], labelled['best_prediction'] = 0, None

        for pred in predictions:
            overlap = jaccard(pred['start'], pred['end'],
                              labelled['start'], labelled['end'])

            # Strict comparison: on ties, the first candidate seen is kept.
            if overlap > labelled['best_overlap']:
                labelled['best_overlap'] = overlap
                labelled['best_prediction'] = pred

            if overlap > pred['best_overlap']:
                pred['best_overlap'] = overlap
                pred['best_sponsorship'] = labelled

    return sponsor_segments


def calculate_metrics(labelled_words, predictions):
    """Compute time-weighted confusion-matrix metrics for a single video.

    Each word except the last is weighted by the time from its start to the
    next word's start. A word counts as predicted-sponsor when its start
    falls inside any prediction's inclusive ``[start, end]`` range. It counts
    as an actual sponsor when it has a non-None ``category``.

    Returns a dict containing the four confusion totals (in the same time
    units as the word timestamps), ``video_duration``, ``accuracy``,
    ``precision``, ``recall`` and ``f-score``.
    Raises IndexError if ``labelled_words`` is empty.
    """
    metrics = {
        'true_positive': 0,   # labelled sponsor, predicted sponsor
        # labelled sponsor, not predicted (i.e. missed it - the costly error)
        'false_negative': 0,
        # not labelled, predicted sponsor (classified incorrectly; less bad
        # because predictions are checked manually afterwards)
        'false_positive': 0,
        'true_negative': 0,   # not labelled, not predicted
    }

    metrics['video_duration'] = word_end(
        labelled_words[-1]) - word_start(labelled_words[0])

    # Pair each word with its successor; the final word has no successor and
    # therefore contributes no duration.
    for word, next_word in zip(labelled_words, labelled_words[1:]):
        # TODO make sure words with manual transcripts
        duration = next_word['start'] - word['start']

        predicted_sponsor = any(
            p['start'] <= word['start'] <= p['end'] for p in predictions)
        actual_sponsor = word.get('category') is not None

        if predicted_sponsor:
            bucket = 'true_positive' if actual_sponsor else 'false_positive'
        else:
            bucket = 'false_negative' if actual_sponsor else 'true_negative'
        metrics[bucket] += duration

    tp = metrics['true_positive']
    tn = metrics['true_negative']
    fp = metrics['false_positive']
    fn = metrics['false_negative']

    # NOTE A zero denominator yields 1 (https://stats.stackexchange.com/a/1775):
    # (Precision) TP+FP=0: every instance was predicted negative
    # (Recall)    TP+FN=0: the input contained no positive cases

    # Fraction of all time classified correctly (full formula kept on purpose).
    total = tp + tn + fp + fn
    metrics['accuracy'] = (tp + tn) / total if total > 0 else 1

    # Share of predicted-sponsor time that really was sponsor.
    predicted_positive = tp + fp
    metrics['precision'] = tp / predicted_positive if predicted_positive > 0 else 1

    # Share of real sponsor time that was predicted.
    actual_positive = tp + fn
    metrics['recall'] = tp / actual_positive if actual_positive > 0 else 1

    # Harmonic mean of precision and recall:
    # https://deepai.org/machine-learning-glossary-and-terms/f-score
    precision, recall = metrics['precision'], metrics['recall']
    denominator = precision + recall
    metrics['f-score'] = 2 * (precision * recall) / \
        denominator if denominator > 0 else 0

    return metrics


# Public innertube API key. It is stored base64-encoded only so that it is not
# incorrectly flagged (e.g. by secret scanners) as a leaked credential.
INNERTUBE_KEY = base64.b64decode(
    b'QUl6YVN5QU9fRkoyU2xxVThRNFNURUhMR0NpbHdfWTlfMTFxY1c4'
).decode('utf-8')

# Client description sent as the 'context' field of innertube browse requests.
YT_CONTEXT = {
    'client': dict(
        userAgent='Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36,gzip(gfe)',
        clientName='WEB',
        clientVersion='2.20211221.00.00',
    ),
}

# Finds the JSON object assigned to ytInitialData in a YouTube page's HTML;
# group(1) captures the JSON text.
_YT_INITIAL_DATA_RE = r'(?:window\s*\[\s*["\']ytInitialData["\']\s*\]|ytInitialData)\s*=\s*({.+?})\s*;\s*(?:var\s+meta|</script|\n)'


def get_all_channel_vids(channel_id, timeout=30):
    """Yield the video ids of a YouTube channel's playlist, page by page.

    The first page is the playlist whose id is ``channel_id`` with its
    leading ``UC`` replaced by ``UU``. The video list is read from the
    ``ytInitialData`` JSON embedded in that page's HTML. Later pages are
    requested from the innertube ``browse`` endpoint using the continuation
    token found on the previous page, until a page has no token.

    Args:
        channel_id: Channel id, normally starting with ``UC``.
        timeout: Seconds to wait for each HTTP request. Without a timeout,
            ``requests`` would block indefinitely on a stalled connection.

    Yields:
        Video id strings, in playlist order.

    Raises:
        requests.HTTPError: if YouTube responds with an error status.
        RuntimeError: if the playlist page contains no ``ytInitialData``.
    """
    # Only swap the prefix; an interior 'UC' must not be touched.
    if channel_id.startswith('UC'):
        playlist_id = 'UU' + channel_id[2:]
    else:
        playlist_id = channel_id

    continuation = None
    while True:
        if continuation is None:
            response = requests.get(
                'https://www.youtube.com/playlist',
                params={'list': playlist_id}, timeout=timeout)
            response.raise_for_status()

            match = re.search(_YT_INITIAL_DATA_RE, response.text)
            if match is None:
                raise RuntimeError(
                    f'Could not find ytInitialData on playlist page for channel {channel_id}')
            initial_data = json.loads(match.group(1))
            items = initial_data['contents']['twoColumnBrowseResultsRenderer']['tabs'][0]['tabRenderer']['content'][
                'sectionListRenderer']['contents'][0]['itemSectionRenderer']['contents'][0]['playlistVideoListRenderer']['contents']
        else:
            response = requests.post(
                'https://www.youtube.com/youtubei/v1/browse',
                params={'key': INNERTUBE_KEY},
                json={'context': YT_CONTEXT, 'continuation': continuation},
                timeout=timeout)
            response.raise_for_status()
            items = response.json()[
                'onResponseReceivedActions'][0]['appendContinuationItemsAction']['continuationItems']

        # A page holds video entries plus, if more pages exist, one
        # continuation entry carrying the token for the next request.
        new_token = None
        for item in items:
            info = item.get('playlistVideoRenderer')
            if info:
                yield info['videoId']
                continue

            info = item.get('continuationItemRenderer')
            if info:
                new_token = info['continuationEndpoint']['continuationCommand']['token']

        if new_token is None:
            break
        continuation = new_token


def _select_video_ids(evaluation_args, final_data):
    """Return the list of video ids to evaluate.

    With ``channel_id`` set, the channel's videos are fetched in playlist
    order and ``start_index``/``max_videos`` select a window of them.
    Otherwise the keys of ``final_data`` are shuffled and the same window
    options are applied to the shuffled list.
    """
    start_index = evaluation_args.start_index
    max_videos = evaluation_args.max_videos

    if evaluation_args.channel_id is not None:
        start = start_index or 0
        end = None if max_videos is None else start + max_videos

        video_ids = list(itertools.islice(get_all_channel_vids(
            evaluation_args.channel_id), start, end))
        print('Found', len(video_ids), 'for channel', evaluation_args.channel_id)
        return video_ids

    video_ids = list(final_data.keys())
    # NOTE(review): unless a random seed is fixed elsewhere, this order changes
    # between runs, so start_index cannot reliably resume a previous run.
    random.shuffle(video_ids)

    if start_index is not None:
        video_ids = video_ids[start_index:]

    if max_videos is not None:
        video_ids = video_ids[:max_videos]

    return video_ids


def _print_missed_segments(video_id, missed_segments):
    """Print predictions with no overlapping label, plus a submission link."""
    print(' - Missed segments:')
    segments_to_submit = []
    for i, missed_segment in enumerate(missed_segments, start=1):
        print(f'\t#{i}:', seconds_to_time(
            missed_segment['start']), '-->', seconds_to_time(missed_segment['end']))
        print('\t\tText: "', ' '.join(
            [w['text'] for w in missed_segment['words']]), '"', sep='')
        print('\t\tCategory:',
              missed_segment.get('category'))
        print('\t\tProbability:',
              missed_segment.get('probability'))

        segments_to_submit.append({
            'segment': [missed_segment['start'], missed_segment['end']],
            'category': missed_segment['category'].lower(),
            'actionType': 'skip'
        })

    json_data = quote(json.dumps(segments_to_submit))
    print(
        f'\tSubmit: https://www.youtube.com/watch?v={video_id}#segments={json_data}')


def _print_incorrect_segments(words, incorrect_segments):
    """Print labelled segments that no prediction overlapped."""
    print(' - Incorrect segments:')
    for i, incorrect_segment in enumerate(incorrect_segments, start=1):
        print(f'\t#{i}:', seconds_to_time(
            incorrect_segment['start']), '-->', seconds_to_time(incorrect_segment['end']))

        seg_words = extract_segment(
            words, incorrect_segment['start'], incorrect_segment['end'])
        print('\t\tText: "', ' '.join(
            [w['text'] for w in seg_words]), '"', sep='')
        print('\t\tUUID:', incorrect_segment['uuid'])
        print('\t\tCategory:',
              incorrect_segment['category'])
        print('\t\tVotes:', incorrect_segment['votes'])
        print('\t\tViews:', incorrect_segment['views'])
        print('\t\tLocked:', incorrect_segment['locked'])


def main():
    """Evaluate a trained model against labelled sponsor segments.

    Parses command-line arguments and runs ``predict`` on each selected
    video that has labels and words. It accumulates per-video metrics,
    prints predictions and labels that found no partner, and writes the
    per-video metrics to ``evaluation_args.output_file``. Pressing Ctrl+C
    stops the loop early; the metrics gathered so far are still saved.
    """
    hf_parser = HfArgumentParser((
        EvaluationArguments,
        DatasetArguments,
        SegmentationArguments,
        ClassifierArguments,
        GeneralArguments
    ))

    evaluation_args, dataset_args, segmentation_args, classifier_args, _ = hf_parser.parse_args_into_dataclasses()

    model, tokenizer = get_model_tokenizer(evaluation_args.model_path)

    # # TODO find better way of evaluating videos not trained on
    # dataset = load_dataset('json', data_files=os.path.join(
    #     dataset_args.data_dir, dataset_args.validation_file))['train']
    # video_ids = [row['video_id'] for row in dataset]

    # Load labelled data (JSON text is UTF-8; do not rely on the locale default).
    final_path = os.path.join(
        dataset_args.data_dir, dataset_args.processed_file)

    with open(final_path, encoding='utf-8') as fp:
        final_data = json.load(fp)

    video_ids = _select_video_ids(evaluation_args, final_data)

    # TODO option to choose categories

    total_accuracy = 0
    total_precision = 0
    total_recall = 0
    total_fscore = 0

    out_metrics = []

    try:
        with tqdm(video_ids) as progress:
            for video_index, video_id in enumerate(progress):

                progress.set_description(f'Processing {video_id}')

                sponsor_segments = final_data.get(video_id)
                if not sponsor_segments:
                    # TODO remove - parse using whole database
                    # Unlabelled videos are skipped here, so the code below may
                    # assume sponsor_segments is non-empty. Handle that case
                    # explicitly if this skip is ever removed.
                    continue

                words = get_words(video_id)
                if not words:
                    continue

                # Make predictions
                predictions = predict(video_id, model, tokenizer,
                                      segmentation_args, words, classifier_args)

                labelled_words = add_labels_to_words(words, sponsor_segments)
                met = calculate_metrics(labelled_words, predictions)
                met['video_id'] = video_id

                out_metrics.append(met)

                total_accuracy += met['accuracy']
                total_precision += met['precision']
                total_recall += met['recall']
                total_fscore += met['f-score']

                progress.set_postfix({
                    'accuracy': total_accuracy/len(out_metrics),
                    'precision':  total_precision/len(out_metrics),
                    'recall':  total_recall/len(out_metrics),
                    'f-score': total_fscore/len(out_metrics)
                })

                labelled_predicted_segments = attach_predictions_to_sponsor_segments(
                    predictions, sponsor_segments)

                # Identify possible issues:
                # predicted, but not in database
                missed_segments = [
                    prediction for prediction in predictions if prediction['best_sponsorship'] is None]
                # in database, but not predicted
                incorrect_segments = [
                    seg for seg in labelled_predicted_segments if seg['best_prediction'] is None]

                if missed_segments or incorrect_segments:
                    print(f'Issues identified for {video_id} (#{video_index})')
                    if missed_segments:
                        _print_missed_segments(video_id, missed_segments)
                    if incorrect_segments:
                        _print_incorrect_segments(words, incorrect_segments)
                    print()

    except KeyboardInterrupt:
        # Deliberate: stop early but still save and summarise what was collected.
        pass

    df = pd.DataFrame(out_metrics)

    df.to_csv(evaluation_args.output_file)
    # 'video_id' is a string column; pandas >= 2.0 raises TypeError from
    # mean() unless non-numeric columns are excluded.
    print(df.mean(numeric_only=True))


if __name__ == '__main__':
    main()