Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import os
|
2 |
-
from
|
|
|
3 |
import spacy
|
4 |
import re
|
5 |
|
@@ -7,8 +8,8 @@ import re
|
|
7 |
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
|
8 |
os.environ['MPLCONFIGDIR'] = '/tmp/.matplotlib'
|
9 |
|
10 |
-
# Initialize
|
11 |
-
app =
|
12 |
|
13 |
# Load the spaCy models once
|
14 |
nlp = spacy.load("en_core_web_sm")
|
@@ -16,6 +17,10 @@ nlp_coref = spacy.load("en_coreference_web_trf")
|
|
16 |
|
17 |
REPLACE_PRONOUNS = {"he", "she", "they", "He", "She", "They"}
|
18 |
|
|
|
|
|
|
|
|
|
19 |
def extract_core_name(mention_text, main_characters):
|
20 |
words = mention_text.split()
|
21 |
for character in main_characters:
|
@@ -84,14 +89,13 @@ def process_text(text, main_characters):
|
|
84 |
else:
|
85 |
return text
|
86 |
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
resolved_text = process_text(text, main_characters.split(","))
|
94 |
-
return jsonify({"resolved_text": resolved_text})
|
95 |
|
96 |
if __name__ == "__main__":
|
97 |
-
|
|
|
|
1 |
import os
|
2 |
+
from fastapi import FastAPI, HTTPException
|
3 |
+
from pydantic import BaseModel
|
4 |
import spacy
|
5 |
import re
|
6 |
|
|
|
8 |
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
|
9 |
os.environ['MPLCONFIGDIR'] = '/tmp/.matplotlib'
|
10 |
|
11 |
+
# Initialize FastAPI app
|
12 |
+
app = FastAPI()
|
13 |
|
14 |
# Load the spaCy models once
|
15 |
nlp = spacy.load("en_core_web_sm")
|
|
|
17 |
|
18 |
REPLACE_PRONOUNS = {"he", "she", "they", "He", "She", "They"}
|
19 |
|
20 |
+
class CorefRequest(BaseModel):
|
21 |
+
text: str
|
22 |
+
main_characters: str
|
23 |
+
|
24 |
def extract_core_name(mention_text, main_characters):
|
25 |
words = mention_text.split()
|
26 |
for character in main_characters:
|
|
|
89 |
else:
|
90 |
return text
|
91 |
|
92 |
+
@app.post("/predict")
|
93 |
+
async def predict(coref_request: CorefRequest):
|
94 |
+
resolved_text = process_text(coref_request.text, coref_request.main_characters.split(","))
|
95 |
+
if resolved_text:
|
96 |
+
return {"resolved_text": resolved_text}
|
97 |
+
raise HTTPException(status_code=400, detail="Coreference resolution failed")
|
|
|
|
|
98 |
|
99 |
if __name__ == "__main__":
|
100 |
+
import uvicorn
|
101 |
+
uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
|