mrmft commited on
Commit
0571449
1 Parent(s): 1a942ce

Update kpe_ranker.py

Browse files
Files changed (1) hide show
  1. kpe_ranker.py +10 -2
kpe_ranker.py CHANGED
@@ -3,10 +3,18 @@ import utils
3
  import os
4
  from sentence_transformers import SentenceTransformer
5
  import ranker
 
6
 
7
  class KpeRanker:
8
- def __init__(self):
9
- TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model', 'trained_model_10000.pt')
 
 
 
 
 
 
 
10
  self.kpe = KPE(trained_kpe_model= TRAINED_MODEL_ADDR, flair_ner_model='flair/ner-english-ontonotes-large', device='cpu')
11
  self.ranker_transformer = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2', device='cpu')
12
 
 
3
  import os
4
  from sentence_transformers import SentenceTransformer
5
  import ranker
6
+ from huggingface_hub import hf_hub_download
7
 
8
  class KpeRanker:
9
+ def __init__(self):
10
+ model_path = "/root/.cache/huggingface/hub/models--ahdsoft--persian-keyphrase-extraction-model/trained_model_10000.pt"
11
+ if os.path.isfile(file_path):
12
+ TRAINED_MODEL_ADDR = model_path
13
+ else:
14
+ hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json")
15
+ TRAINED_MODEL_ADDR = model_path
16
+
17
+ # TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model', 'trained_model_10000.pt')
18
  self.kpe = KPE(trained_kpe_model= TRAINED_MODEL_ADDR, flair_ner_model='flair/ner-english-ontonotes-large', device='cpu')
19
  self.ranker_transformer = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2', device='cpu')
20