Spaces:
Runtime error
Runtime error
AhdCompnay
commited on
Commit
•
5199291
1
Parent(s):
84b31d5
Update kpe_ranker.py
Browse files- kpe_ranker.py +5 -2
kpe_ranker.py
CHANGED
@@ -10,6 +10,9 @@ class KpeRanker:
|
|
10 |
model_name = os.environ.get("MODEL_NAME")
|
11 |
model_repo = os.environ.get("MODEL_REPO")
|
12 |
model_token = os.environ.get("MODEL_TOKEN")
|
|
|
|
|
|
|
13 |
|
14 |
local_dir = "./"
|
15 |
model_path = os.path.join(local_dir, model_name)
|
@@ -17,8 +20,8 @@ class KpeRanker:
|
|
17 |
hf_hub_download(repo_id=model_repo, filename=model_name, local_dir=local_dir, token=model_token)
|
18 |
TRAINED_MODEL_ADDR = model_path
|
19 |
# TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model', 'trained_model_10000.pt')
|
20 |
-
self.kpe = KPE(trained_kpe_model= TRAINED_MODEL_ADDR, flair_ner_model=
|
21 |
-
self.ranker_transformer = SentenceTransformer(
|
22 |
|
23 |
|
24 |
def extract(self, text, count, using_ner, return_sorted):
|
|
|
10 |
model_name = os.environ.get("MODEL_NAME")
|
11 |
model_repo = os.environ.get("MODEL_REPO")
|
12 |
model_token = os.environ.get("MODEL_TOKEN")
|
13 |
+
ner_model = os.environ.get("NER_MODEL")
|
14 |
+
transformer_model = os.environ.get("TRANSFORMER_MODEL")
|
15 |
+
|
16 |
|
17 |
local_dir = "./"
|
18 |
model_path = os.path.join(local_dir, model_name)
|
|
|
20 |
hf_hub_download(repo_id=model_repo, filename=model_name, local_dir=local_dir, token=model_token)
|
21 |
TRAINED_MODEL_ADDR = model_path
|
22 |
# TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model', 'trained_model_10000.pt')
|
23 |
+
self.kpe = KPE(trained_kpe_model= TRAINED_MODEL_ADDR, flair_ner_model= ner_model , device='cpu')
|
24 |
+
self.ranker_transformer = SentenceTransformer(transformer_model, device='cpu')
|
25 |
|
26 |
|
27 |
def extract(self, text, count, using_ner, return_sorted):
|