yonikremer commited on
Commit
28a440d
1 Parent(s): df1b7f8

adapted the tests to the changes in the code

Browse files
Files changed (1) hide show
  1. tests.py +1 -8
tests.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import shutil
3
 
4
  import pytest as pytest
5
  from grouped_sampling import GroupedSamplingPipeLine
@@ -29,9 +28,6 @@ def test_get_supported_model_names():
29
 
30
 
31
  def test_on_server_start():
32
- if os.path.exists(HUGGING_FACE_CACHE_DIR):
33
- shutil.rmtree(HUGGING_FACE_CACHE_DIR)
34
- assert not os.path.exists(HUGGING_FACE_CACHE_DIR)
35
  download_useful_models()
36
  assert os.path.exists(HUGGING_FACE_CACHE_DIR)
37
  assert len(os.listdir(HUGGING_FACE_CACHE_DIR)) > 0
@@ -64,10 +60,7 @@ def test_create_pipeline(model_name: str):
64
  assert pipeline is not None
65
  assert pipeline.model_name == model_name
66
  assert pipeline.wrapped_model.group_size == 5
67
- assert pipeline.wrapped_model.end_of_sentence_stop is True
68
- model_folder_name = "models--" + model_name.replace("/", "--")
69
- model_path = os.path.join(HUGGING_FACE_CACHE_DIR, model_folder_name)
70
- shutil.rmtree(model_path)
71
 
72
 
73
  if __name__ == "__main__":
 
1
  import os
 
2
 
3
  import pytest as pytest
4
  from grouped_sampling import GroupedSamplingPipeLine
 
28
 
29
 
30
  def test_on_server_start():
 
 
 
31
  download_useful_models()
32
  assert os.path.exists(HUGGING_FACE_CACHE_DIR)
33
  assert len(os.listdir(HUGGING_FACE_CACHE_DIR)) > 0
 
60
  assert pipeline is not None
61
  assert pipeline.model_name == model_name
62
  assert pipeline.wrapped_model.group_size == 5
63
+ assert pipeline.wrapped_model.end_of_sentence_stop is False
 
 
 
64
 
65
 
66
  if __name__ == "__main__":