Joshua Lochner commited on
Commit
9b9ffd0
1 Parent(s): bdcc521

Fix prediction and evaluation arguments

Browse files
Files changed (3) hide show
  1. src/evaluate.py +3 -4
  2. src/predict.py +15 -7
  3. src/train.py +1 -3
src/evaluate.py CHANGED
@@ -7,7 +7,7 @@ from transformers import (
7
  from preprocess import DatasetArguments, ProcessedArguments, get_words
8
  from model import get_classifier_vectorizer
9
  from shared import device
10
- from predict import ClassifierArguments, PredictArguments, predict, filter_predictions
11
  from segment import word_start, word_end, SegmentationArguments, add_labels_to_words
12
  import pandas as pd
13
  from dataclasses import dataclass, field
@@ -19,7 +19,7 @@ import random
19
 
20
 
21
  @dataclass
22
- class EvaluationArguments:
23
  """
24
  Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
25
  """
@@ -29,8 +29,7 @@ class EvaluationArguments:
29
  'help': 'The number of videos to test on'
30
  }
31
  )
32
- model_path: Optional[str] = PredictArguments.__dataclass_fields__[
33
- 'model_path']
34
  data_dir: Optional[str] = DatasetArguments.__dataclass_fields__['data_dir']
35
  dataset: Optional[str] = DatasetArguments.__dataclass_fields__[
36
  'validation_file']
 
7
  from preprocess import DatasetArguments, ProcessedArguments, get_words
8
  from model import get_classifier_vectorizer
9
  from shared import device
10
+ from predict import ClassifierArguments, predict, filter_predictions, TrainingOutputArguments
11
  from segment import word_start, word_end, SegmentationArguments, add_labels_to_words
12
  import pandas as pd
13
  from dataclasses import dataclass, field
 
19
 
20
 
21
  @dataclass
22
+ class EvaluationArguments(TrainingOutputArguments):
23
  """
24
  Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
25
  """
 
29
  'help': 'The number of videos to test on'
30
  }
31
  )
32
+
 
33
  data_dir: Optional[str] = DatasetArguments.__dataclass_fields__['data_dir']
34
  dataset: Optional[str] = DatasetArguments.__dataclass_fields__[
35
  'validation_file']
src/predict.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from shared import OutputArguments
2
  from typing import Optional
3
  from segment import (
@@ -21,7 +22,6 @@ from dataclasses import dataclass, field
21
  from transformers import HfArgumentParser
22
  from shared import device
23
  import logging
24
- from transformers.trainer_utils import get_last_checkpoint
25
 
26
 
27
  def seconds_to_time(seconds):
@@ -31,12 +31,7 @@ def seconds_to_time(seconds):
31
 
32
 
33
  @dataclass
34
- class PredictArguments:
35
-
36
- video_id: str = field(
37
- metadata={
38
- 'help': 'Video to predict sponsorship segments for'}
39
- )
40
 
41
  model_path: str = field(
42
  default=None,
@@ -59,6 +54,15 @@ class PredictArguments:
59
  'Unable to find model, explicitly set `--model_path`')
60
 
61
 
 
 
 
 
 
 
 
 
 
62
  SPONSOR_MATCH_RE = fr'(?<={CustomTokens.START_SPONSOR.value})\s*(.*?)\s*(?={CustomTokens.END_SPONSOR.value}|$)'
63
 
64
  MATCH_WINDOW = 25 # Increase for accuracy, but takes longer: O(n^3)
@@ -252,6 +256,10 @@ def main():
252
  ))
253
  predict_args, segmentation_args, classifier_args = hf_parser.parse_args_into_dataclasses()
254
 
 
 
 
 
255
  model = AutoModelForSeq2SeqLM.from_pretrained(predict_args.model_path)
256
  model.to(device())
257
 
 
1
+ from transformers.trainer_utils import get_last_checkpoint
2
  from shared import OutputArguments
3
  from typing import Optional
4
  from segment import (
 
22
  from transformers import HfArgumentParser
23
  from shared import device
24
  import logging
 
25
 
26
 
27
  def seconds_to_time(seconds):
 
31
 
32
 
33
  @dataclass
34
+ class TrainingOutputArguments:
 
 
 
 
 
35
 
36
  model_path: str = field(
37
  default=None,
 
54
  'Unable to find model, explicitly set `--model_path`')
55
 
56
 
57
+ @dataclass
58
+ class PredictArguments(TrainingOutputArguments):
59
+ video_id: str = field(
60
+ default=None,
61
+ metadata={
62
+ 'help': 'Video to predict sponsorship segments for'}
63
+ )
64
+
65
+
66
  SPONSOR_MATCH_RE = fr'(?<={CustomTokens.START_SPONSOR.value})\s*(.*?)\s*(?={CustomTokens.END_SPONSOR.value}|$)'
67
 
68
  MATCH_WINDOW = 25 # Increase for accuracy, but takes longer: O(n^3)
 
256
  ))
257
  predict_args, segmentation_args, classifier_args = hf_parser.parse_args_into_dataclasses()
258
 
259
+ if predict_args.video_id is None:
260
+ print('No video ID supplied. Use `--video_id`.')
261
+ return
262
+
263
  model = AutoModelForSeq2SeqLM.from_pretrained(predict_args.model_path)
264
  model.to(device())
265
 
src/train.py CHANGED
@@ -1,10 +1,8 @@
1
  from preprocess import load_datasets, DatasetArguments
2
  from predict import ClassifierArguments, SPONSOR_MATCH_RE, DEFAULT_TOKEN_PREFIX
3
- from shared import device
4
- from shared import GeneralArguments, OutputArguments
5
  from model import ModelArguments
6
  import transformers
7
- import logging
8
  from model import get_model, get_tokenizer
9
  import logging
10
  import os
 
1
  from preprocess import load_datasets, DatasetArguments
2
  from predict import ClassifierArguments, SPONSOR_MATCH_RE, DEFAULT_TOKEN_PREFIX
3
+ from shared import device, GeneralArguments, OutputArguments
 
4
  from model import ModelArguments
5
  import transformers
 
6
  from model import get_model, get_tokenizer
7
  import logging
8
  import os