Spaces:
Runtime error
Runtime error
Joshua Lochner
commited on
Commit
•
9b9ffd0
1
Parent(s):
bdcc521
Fix prediction and evaluation arguments
Browse files- src/evaluate.py +3 -4
- src/predict.py +15 -7
- src/train.py +1 -3
src/evaluate.py
CHANGED
@@ -7,7 +7,7 @@ from transformers import (
|
|
7 |
from preprocess import DatasetArguments, ProcessedArguments, get_words
|
8 |
from model import get_classifier_vectorizer
|
9 |
from shared import device
|
10 |
-
from predict import ClassifierArguments,
|
11 |
from segment import word_start, word_end, SegmentationArguments, add_labels_to_words
|
12 |
import pandas as pd
|
13 |
from dataclasses import dataclass, field
|
@@ -19,7 +19,7 @@ import random
|
|
19 |
|
20 |
|
21 |
@dataclass
|
22 |
-
class EvaluationArguments:
|
23 |
"""
|
24 |
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
25 |
"""
|
@@ -29,8 +29,7 @@ class EvaluationArguments:
|
|
29 |
'help': 'The number of videos to test on'
|
30 |
}
|
31 |
)
|
32 |
-
|
33 |
-
'model_path']
|
34 |
data_dir: Optional[str] = DatasetArguments.__dataclass_fields__['data_dir']
|
35 |
dataset: Optional[str] = DatasetArguments.__dataclass_fields__[
|
36 |
'validation_file']
|
|
|
7 |
from preprocess import DatasetArguments, ProcessedArguments, get_words
|
8 |
from model import get_classifier_vectorizer
|
9 |
from shared import device
|
10 |
+
from predict import ClassifierArguments, predict, filter_predictions, TrainingOutputArguments
|
11 |
from segment import word_start, word_end, SegmentationArguments, add_labels_to_words
|
12 |
import pandas as pd
|
13 |
from dataclasses import dataclass, field
|
|
|
19 |
|
20 |
|
21 |
@dataclass
|
22 |
+
class EvaluationArguments(TrainingOutputArguments):
|
23 |
"""
|
24 |
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
25 |
"""
|
|
|
29 |
'help': 'The number of videos to test on'
|
30 |
}
|
31 |
)
|
32 |
+
|
|
|
33 |
data_dir: Optional[str] = DatasetArguments.__dataclass_fields__['data_dir']
|
34 |
dataset: Optional[str] = DatasetArguments.__dataclass_fields__[
|
35 |
'validation_file']
|
src/predict.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from shared import OutputArguments
|
2 |
from typing import Optional
|
3 |
from segment import (
|
@@ -21,7 +22,6 @@ from dataclasses import dataclass, field
|
|
21 |
from transformers import HfArgumentParser
|
22 |
from shared import device
|
23 |
import logging
|
24 |
-
from transformers.trainer_utils import get_last_checkpoint
|
25 |
|
26 |
|
27 |
def seconds_to_time(seconds):
|
@@ -31,12 +31,7 @@ def seconds_to_time(seconds):
|
|
31 |
|
32 |
|
33 |
@dataclass
|
34 |
-
class
|
35 |
-
|
36 |
-
video_id: str = field(
|
37 |
-
metadata={
|
38 |
-
'help': 'Video to predict sponsorship segments for'}
|
39 |
-
)
|
40 |
|
41 |
model_path: str = field(
|
42 |
default=None,
|
@@ -59,6 +54,15 @@ class PredictArguments:
|
|
59 |
'Unable to find model, explicitly set `--model_path`')
|
60 |
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
SPONSOR_MATCH_RE = fr'(?<={CustomTokens.START_SPONSOR.value})\s*(.*?)\s*(?={CustomTokens.END_SPONSOR.value}|$)'
|
63 |
|
64 |
MATCH_WINDOW = 25 # Increase for accuracy, but takes longer: O(n^3)
|
@@ -252,6 +256,10 @@ def main():
|
|
252 |
))
|
253 |
predict_args, segmentation_args, classifier_args = hf_parser.parse_args_into_dataclasses()
|
254 |
|
|
|
|
|
|
|
|
|
255 |
model = AutoModelForSeq2SeqLM.from_pretrained(predict_args.model_path)
|
256 |
model.to(device())
|
257 |
|
|
|
1 |
+
from transformers.trainer_utils import get_last_checkpoint
|
2 |
from shared import OutputArguments
|
3 |
from typing import Optional
|
4 |
from segment import (
|
|
|
22 |
from transformers import HfArgumentParser
|
23 |
from shared import device
|
24 |
import logging
|
|
|
25 |
|
26 |
|
27 |
def seconds_to_time(seconds):
|
|
|
31 |
|
32 |
|
33 |
@dataclass
|
34 |
+
class TrainingOutputArguments:
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
model_path: str = field(
|
37 |
default=None,
|
|
|
54 |
'Unable to find model, explicitly set `--model_path`')
|
55 |
|
56 |
|
57 |
+
@dataclass
|
58 |
+
class PredictArguments(TrainingOutputArguments):
|
59 |
+
video_id: str = field(
|
60 |
+
default=None,
|
61 |
+
metadata={
|
62 |
+
'help': 'Video to predict sponsorship segments for'}
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
SPONSOR_MATCH_RE = fr'(?<={CustomTokens.START_SPONSOR.value})\s*(.*?)\s*(?={CustomTokens.END_SPONSOR.value}|$)'
|
67 |
|
68 |
MATCH_WINDOW = 25 # Increase for accuracy, but takes longer: O(n^3)
|
|
|
256 |
))
|
257 |
predict_args, segmentation_args, classifier_args = hf_parser.parse_args_into_dataclasses()
|
258 |
|
259 |
+
if predict_args.video_id is None:
|
260 |
+
print('No video ID supplied. Use `--video_id`.')
|
261 |
+
return
|
262 |
+
|
263 |
model = AutoModelForSeq2SeqLM.from_pretrained(predict_args.model_path)
|
264 |
model.to(device())
|
265 |
|
src/train.py
CHANGED
@@ -1,10 +1,8 @@
|
|
1 |
from preprocess import load_datasets, DatasetArguments
|
2 |
from predict import ClassifierArguments, SPONSOR_MATCH_RE, DEFAULT_TOKEN_PREFIX
|
3 |
-
from shared import device
|
4 |
-
from shared import GeneralArguments, OutputArguments
|
5 |
from model import ModelArguments
|
6 |
import transformers
|
7 |
-
import logging
|
8 |
from model import get_model, get_tokenizer
|
9 |
import logging
|
10 |
import os
|
|
|
1 |
from preprocess import load_datasets, DatasetArguments
|
2 |
from predict import ClassifierArguments, SPONSOR_MATCH_RE, DEFAULT_TOKEN_PREFIX
|
3 |
+
from shared import device, GeneralArguments, OutputArguments
|
|
|
4 |
from model import ModelArguments
|
5 |
import transformers
|
|
|
6 |
from model import get_model, get_tokenizer
|
7 |
import logging
|
8 |
import os
|