Joshua Lochner committed on
Commit
bb58e90
1 Parent(s): f9281a4

Move `seconds_to_time` to shared

Browse files
Files changed (2) hide show
  1. src/predict.py +18 -24
  2. src/shared.py +14 -1
src/predict.py CHANGED
@@ -5,7 +5,8 @@ from typing import Optional
5
  from segment import (
6
  generate_segments,
7
  extract_segment,
8
- SAFETY_TOKENS,
 
9
  CustomTokens,
10
  word_start,
11
  word_end,
@@ -13,7 +14,7 @@ from segment import (
13
  )
14
  import preprocess
15
  from errors import TranscriptError
16
- from model import get_classifier_vectorizer
17
  from transformers import (
18
  AutoModelForSeq2SeqLM,
19
  AutoTokenizer,
@@ -26,25 +27,15 @@ import logging
26
 
27
  import re
28
 
29
-
30
- def seconds_to_time(seconds, remove_leading_zeroes=False):
31
- fractional = round(seconds % 1, 3)
32
- fractional = '' if fractional == 0 else str(fractional)[1:]
33
- h, remainder = divmod(abs(int(seconds)), 3600)
34
- m, s = divmod(remainder, 60)
35
- hms = f'{h:02}:{m:02}:{s:02}'
36
- if remove_leading_zeroes:
37
- hms = re.sub(r'^0(?:0:0?)?', '', hms)
38
- return f"{'-' if seconds < 0 else ''}{hms}{fractional}"
39
-
40
-
41
  @dataclass
42
  class TrainingOutputArguments:
43
 
44
  model_path: str = field(
45
  default=None,
46
  metadata={
47
- 'help': 'Path to pretrained model used for prediction'}
 
48
  )
49
 
50
  output_dir: Optional[str] = OutputArguments.__dataclass_fields__[
@@ -106,7 +97,8 @@ class ClassifierArguments:
106
  default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
107
 
108
 
109
- def filter_and_add_probabilities(predictions, classifier_args): # classifier, vectorizer,
 
110
  """Use classifier to filter predictions"""
111
  if not predictions:
112
  return predictions
@@ -135,7 +127,7 @@ def filter_and_add_probabilities(predictions, classifier_args): # classifier, v
135
  continue # Ignore
136
 
137
  if (prediction['category'] not in predicted_probabilities) \
138
- or (classifier_category is not None and classifier_probability > 0.5): # TODO make param
139
  # Unknown category or we are confident enough to overrule,
140
  # so change category to what was predicted by classifier
141
  prediction['category'] = classifier_category
@@ -175,7 +167,8 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
175
 
176
  # TODO add back
177
  if classifier_args is not None:
178
- predictions = filter_and_add_probabilities(predictions, classifier_args)
 
179
 
180
  return predictions
181
 
@@ -205,8 +198,12 @@ def predict_sponsor_text(text, model, tokenizer):
205
  input_ids = tokenizer(
206
  f'{CustomTokens.EXTRACT_SEGMENTS_PREFIX.value} {text}', return_tensors='pt', truncation=True).input_ids.to(device())
207
 
208
- # Can't be longer than input length + SAFETY_TOKENS or model input dim
209
- max_out_len = min(len(input_ids[0]) + SAFETY_TOKENS, model.model_dim)
 
 
 
 
210
  outputs = model.generate(input_ids, max_length=max_out_len)
211
 
212
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -300,10 +297,7 @@ def main():
300
  print('No video ID supplied. Use `--video_id`.')
301
  return
302
 
303
- model = AutoModelForSeq2SeqLM.from_pretrained(predict_args.model_path)
304
- model.to(device())
305
-
306
- tokenizer = AutoTokenizer.from_pretrained(predict_args.model_path)
307
 
308
  predict_args.video_id = predict_args.video_id.strip()
309
  predictions = predict(predict_args.video_id, model, tokenizer,
 
5
  from segment import (
6
  generate_segments,
7
  extract_segment,
8
+ MIN_SAFETY_TOKENS,
9
+ SAFETY_TOKENS_PERCENTAGE,
10
  CustomTokens,
11
  word_start,
12
  word_end,
 
14
  )
15
  import preprocess
16
  from errors import TranscriptError
17
+ from model import get_classifier_vectorizer, get_model_tokenizer
18
  from transformers import (
19
  AutoModelForSeq2SeqLM,
20
  AutoTokenizer,
 
27
 
28
  import re
29
 
30
+ from shared import seconds_to_time
 
 
 
 
 
 
 
 
 
 
 
31
  @dataclass
32
  class TrainingOutputArguments:
33
 
34
  model_path: str = field(
35
  default=None,
36
  metadata={
37
+ 'help': 'Path to pretrained model used for prediction'
38
+ }
39
  )
40
 
41
  output_dir: Optional[str] = OutputArguments.__dataclass_fields__[
 
97
  default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
98
 
99
 
100
+ # classifier, vectorizer,
101
+ def filter_and_add_probabilities(predictions, classifier_args):
102
  """Use classifier to filter predictions"""
103
  if not predictions:
104
  return predictions
 
127
  continue # Ignore
128
 
129
  if (prediction['category'] not in predicted_probabilities) \
130
+ or (classifier_category is not None and classifier_probability > 0.5): # TODO make param
131
  # Unknown category or we are confident enough to overrule,
132
  # so change category to what was predicted by classifier
133
  prediction['category'] = classifier_category
 
167
 
168
  # TODO add back
169
  if classifier_args is not None:
170
+ predictions = filter_and_add_probabilities(
171
+ predictions, classifier_args)
172
 
173
  return predictions
174
 
 
198
  input_ids = tokenizer(
199
  f'{CustomTokens.EXTRACT_SEGMENTS_PREFIX.value} {text}', return_tensors='pt', truncation=True).input_ids.to(device())
200
 
201
+ max_out_len = round(min(
202
+ max(
203
+ len(input_ids[0])/SAFETY_TOKENS_PERCENTAGE,
204
+ len(input_ids[0]) + MIN_SAFETY_TOKENS
205
+ ),
206
+ model.model_dim))
207
  outputs = model.generate(input_ids, max_length=max_out_len)
208
 
209
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
297
  print('No video ID supplied. Use `--video_id`.')
298
  return
299
 
300
+ model, tokenizer = get_model_tokenizer(predict_args.model_path)
 
 
 
301
 
302
  predict_args.video_id = predict_args.video_id.strip()
303
  predictions = predict(predict_args.video_id, model, tokenizer,
src/shared.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gc
2
  from time import time_ns
3
  import random
@@ -11,6 +12,7 @@ from enum import Enum
11
  START_SEGMENT_TEMPLATE = 'START_{}_TOKEN'
12
  END_SEGMENT_TEMPLATE = 'END_{}_TOKEN'
13
 
 
14
  class CustomTokens(Enum):
15
  EXTRACT_SEGMENTS_PREFIX = 'EXTRACT_SEGMENTS: '
16
 
@@ -29,7 +31,7 @@ class CustomTokens(Enum):
29
  LAUGHTER = '[Laughter]'
30
 
31
  PROFANITY = 'PROFANITY_TOKEN'
32
-
33
  # Segment tokens
34
  NO_SEGMENT = 'NO_SEGMENT_TOKEN'
35
 
@@ -103,6 +105,17 @@ def device():
103
  return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
104
 
105
 
 
 
 
 
 
 
 
 
 
 
 
106
  def reset():
107
  torch.clear_autocast_cache()
108
  torch.cuda.empty_cache()
 
1
+ import re
2
  import gc
3
  from time import time_ns
4
  import random
 
12
  START_SEGMENT_TEMPLATE = 'START_{}_TOKEN'
13
  END_SEGMENT_TEMPLATE = 'END_{}_TOKEN'
14
 
15
+
16
  class CustomTokens(Enum):
17
  EXTRACT_SEGMENTS_PREFIX = 'EXTRACT_SEGMENTS: '
18
 
 
31
  LAUGHTER = '[Laughter]'
32
 
33
  PROFANITY = 'PROFANITY_TOKEN'
34
+
35
  # Segment tokens
36
  NO_SEGMENT = 'NO_SEGMENT_TOKEN'
37
 
 
105
  return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
106
 
107
 
108
def seconds_to_time(seconds, remove_leading_zeroes=False):
    """Convert a number of seconds to an ``HH:MM:SS[.fff]`` string.

    Args:
        seconds: Duration in seconds (int or float); may be negative,
            in which case the result is prefixed with ``-``.
        remove_leading_zeroes: If True, strip a leading ``0`` (and, for
            durations under ten minutes, the ``00:`` hour field plus the
            tens digit of the minutes), e.g. ``00:00:59`` -> ``0:59``.

    Returns:
        The formatted time string, with up to 3 fractional-second digits
        appended when the duration is not a whole number of seconds.
    """
    # Work on the magnitude so the fractional part and the whole-second
    # part agree for negative inputs. Using `seconds % 1` directly would
    # be wrong for negatives: Python's floored modulo gives a positive
    # remainder while int() truncates toward zero, so e.g. -1.25 would
    # render as -00:00:01.75 instead of -00:00:01.25.
    magnitude = abs(seconds)
    fractional = round(magnitude % 1, 3)
    # Keep only the ".fff" part (drop the leading "0"); empty if whole.
    fractional = '' if fractional == 0 else str(fractional)[1:]
    h, remainder = divmod(int(magnitude), 3600)
    m, s = divmod(remainder, 60)
    hms = f'{h:02}:{m:02}:{s:02}'
    if remove_leading_zeroes:
        hms = re.sub(r'^0(?:0:0?)?', '', hms)
    return f"{'-' if seconds < 0 else ''}{hms}{fractional}"
117
+
118
+
119
  def reset():
120
  torch.clear_autocast_cache()
121
  torch.cuda.empty_cache()