Peter commited on
Commit
b8c7507
1 Parent(s): 950a38f

🧑‍💻 add --test flag

Browse files

Signed-off-by: Peter <74869040+pszemraj@users.noreply.github.com>

Files changed (1) hide show
  1. app.py +13 -0
app.py CHANGED
@@ -179,16 +179,29 @@ def get_parser():
179
  default=False,
180
  help="turn on verbose logging",
181
  )
 
 
 
 
 
 
 
182
  return parser
183
 
184
 
185
  if __name__ == "__main__":
186
  args = get_parser().parse_args()
187
  default_model = str(args.model)
 
 
 
 
 
188
  model_loc = Path(default_model) # if the model is a path, use it
189
  basic_sc = args.basic_sc # whether to use the baseline spellchecker
190
  gram_model = str(args.gram_model)
191
  device = 0 if torch.cuda.is_available() else -1
 
192
  print(f"CUDA avail is {torch.cuda.is_available()}")
193
 
194
  my_chatbot = (
 
179
  default=False,
180
  help="turn on verbose logging",
181
  )
182
+ parser.add_argument(
183
+ "--test",
184
+ action="store_true",
185
+ default=False,
186
+ help="load the smallest model for simple testing",
187
+ )
188
+
189
  return parser
190
 
191
 
192
  if __name__ == "__main__":
193
  args = get_parser().parse_args()
194
  default_model = str(args.model)
195
+ test = args.test
196
+ if test:
197
+ logging.info("loading the smallest model for testing")
198
+ default_model = "ethzanalytics/distilgpt2-tiny-conversational"
199
+
200
  model_loc = Path(default_model) # if the model is a path, use it
201
  basic_sc = args.basic_sc # whether to use the baseline spellchecker
202
  gram_model = str(args.gram_model)
203
  device = 0 if torch.cuda.is_available() else -1
204
+
205
  print(f"CUDA avail is {torch.cuda.is_available()}")
206
 
207
  my_chatbot = (