orionweller commited on
Commit
d618113
1 Parent(s): 3e7db0b
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -127,9 +127,12 @@ class RepLlamaModel:
127
  return model
128
 
129
  def encode(self, texts, batch_size=48, **kwargs):
130
- self.model = self.model.cuda()
 
 
 
131
  all_embeddings = []
132
- for i in range(0, len(texts), batch_size):
133
  batch_texts = texts[i:i+batch_size]
134
 
135
  batch_dict = create_batch_dict(self.tokenizer, batch_texts, always_add_eos="last")
@@ -143,7 +146,7 @@ class RepLlamaModel:
143
  logger.info(f"Encoded shape: {embeddings.shape}, Norm of first embedding: {torch.norm(embeddings[0]).item()}")
144
  all_embeddings.append(embeddings.cpu().numpy())
145
 
146
- self.model = self.model.cpu()
147
  return np.concatenate(all_embeddings, axis=0)
148
 
149
  def load_corpus_embeddings(dataset_name):
 
127
  return model
128
 
129
  def encode(self, texts, batch_size=48, **kwargs):
130
+ # if model is not on cuda, put it there
131
+ if self.model.device.type != "cuda":
132
+ self.model = self.model.cuda()
133
+
134
  all_embeddings = []
135
+ for i in tqdm.tqdm(range(0, len(texts), batch_size)):
136
  batch_texts = texts[i:i+batch_size]
137
 
138
  batch_dict = create_batch_dict(self.tokenizer, batch_texts, always_add_eos="last")
 
146
  logger.info(f"Encoded shape: {embeddings.shape}, Norm of first embedding: {torch.norm(embeddings[0]).item()}")
147
  all_embeddings.append(embeddings.cpu().numpy())
148
 
149
+ # self.model = self.model.cpu()
150
  return np.concatenate(all_embeddings, axis=0)
151
 
152
  def load_corpus_embeddings(dataset_name):