ronald committed on
Commit
1239913
1 Parent(s): 76cc74c
Files changed (2) hide show
  1. app.py +2 -2
  2. my_perplexity.py +3 -3
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import evaluate
2
  from evaluate.utils import launch_gradio_widget
3
 
4
- CACHE_DIR="/gfs/team/nlp/users/rcardena/tools/huggingface/evaluate"
5
- module = evaluate.load("my_perplexity", module_type="measurement",cache_dir=CACHE_DIR)
6
  launch_gradio_widget(module)
 
1
  import evaluate
2
  from evaluate.utils import launch_gradio_widget
3
 
4
+ METRICS_CACHE_DIR="/gfs/team/nlp/users/rcardena/tools/huggingface/evaluate"
5
+ module = evaluate.load("my_perplexity", module_type="measurement",cache_dir=METRICS_CACHE_DIR)
6
  launch_gradio_widget(module)
my_perplexity.py CHANGED
@@ -105,7 +105,7 @@ class MyPerplexity(evaluate.Measurement):
105
  )
106
 
107
  def _compute(self, predictions, model_id, batch_size: int = 16, add_start_token: bool = True, device=None):
108
-
109
  if device is not None:
110
  assert device in ["gpu", "cpu", "cuda"], "device should be either gpu or cpu."
111
  if device == "gpu":
@@ -113,12 +113,12 @@ class MyPerplexity(evaluate.Measurement):
113
  else:
114
  device = "cuda" if torch.cuda.is_available() else "cpu"
115
 
116
- model = AutoModelForCausalLM.from_pretrained(model_id,cache_dir=self.cache_dir)
117
  model = model.to(device)
118
 
119
  tokenizer = AutoTokenizer.from_pretrained(
120
  model_id,
121
- cache_dir=self.cache_dir,
122
  use_fast="cnn_dailymail" not in model_id,
123
  )
124
 
 
105
  )
106
 
107
  def _compute(self, predictions, model_id, batch_size: int = 16, add_start_token: bool = True, device=None):
108
+ MODEL_CACHE_DIR="/gfs/team/nlp/users/rcardena/tools/huggingface"
109
  if device is not None:
110
  assert device in ["gpu", "cpu", "cuda"], "device should be either gpu or cpu."
111
  if device == "gpu":
 
113
  else:
114
  device = "cuda" if torch.cuda.is_available() else "cpu"
115
 
116
+ model = AutoModelForCausalLM.from_pretrained(model_id,cache_dir=MODEL_CACHE_DIR)
117
  model = model.to(device)
118
 
119
  tokenizer = AutoTokenizer.from_pretrained(
120
  model_id,
121
+ cache_dir=MODEL_CACHE_DIR,
122
  use_fast="cnn_dailymail" not in model_id,
123
  )
124