import os
import re

import spacy
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Set environment variables for writable directories
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
os.environ['MPLCONFIGDIR'] = '/tmp/.matplotlib'

# Initialize FastAPI app
app = FastAPI()

# Load the spaCy models once
nlp = spacy.load("en_core_web_sm")
nlp_coref = spacy.load("en_coreference_web_trf")

REPLACE_PRONOUNS = {"he", "she", "they", "He", "She", "They"}

class CorefRequest(BaseModel):
    text: str
    main_characters: str

def extract_core_name(mention_text, main_characters):
    # Return the listed main character whose name appears in the mention,
    # falling back to the mention's last word when no character matches.
    words = mention_text.split()
    for character in main_characters:
        if character.lower() in mention_text.lower():
            return character
    return words[-1]
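
# Illustrative behaviour of extract_core_name (the example strings are made up,
# not taken from model output):
#   extract_core_name("the young Alice", ["Alice", "Bob"])  -> "Alice"
#   extract_core_name("the stranger", ["Alice", "Bob"])     -> "stranger"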

def calculate_pronoun_density(text):
    # Ratio of replaceable third-person pronouns to PERSON entities; used as a
    # cheap signal for whether coreference resolution is worth running.
    doc = nlp(text)
    pronoun_count = sum(1 for token in doc if token.pos_ == "PRON" and token.text in REPLACE_PRONOUNS)
    named_entity_count = sum(1 for ent in doc.ents if ent.label_ == "PERSON")
    return pronoun_count / max(named_entity_count, 1), named_entity_count
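
# Indicative output only -- the exact numbers depend on the en_core_web_sm
# tagger and NER, so treat this as a sketch rather than a guaranteed result:
#   calculate_pronoun_density("He opened the door. Sarah smiled.")
#   -> (1.0, 1)   # one replaceable pronoun ("He"), one PERSON entity ("Sarah")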

def resolve_coreferences_across_text(text, main_characters):
    doc = nlp_coref(text)
    coref_mapping = {}

    # Map each pronoun token index to the main character its cluster refers to.
    for key, cluster in doc.spans.items():
        if re.match(r"coref_clusters_*", key):
            main_mention = cluster[0]
            core_name = extract_core_name(main_mention.text, main_characters)
            if core_name in main_characters:
                for mention in cluster:
                    for token in mention:
                        if token.text in REPLACE_PRONOUNS:
                            core_name_final = core_name if token.text.istitle() else core_name.lower()
                            coref_mapping[token.i] = core_name_final

    # Rebuild the text, replacing mapped pronouns while avoiding repeating a
    # character's name more than once within the same sentence.
    resolved_tokens = []
    current_sentence_characters = set()
    current_sentence = []
    for i, token in enumerate(doc):
        if token.is_sent_start and current_sentence:
            resolved_tokens.extend(current_sentence)
            current_sentence_characters.clear()
            current_sentence = []
        if i in coref_mapping:
            core_name = coref_mapping[i]
            if core_name not in current_sentence_characters and core_name.lower() not in [t.lower() for t in current_sentence]:
                current_sentence.append(core_name)
                current_sentence_characters.add(core_name)
            else:
                current_sentence.append(token.text)
        else:
            current_sentence.append(token.text)
    resolved_tokens.extend(current_sentence)

    resolved_text = " ".join(resolved_tokens)
    return remove_consecutive_duplicate_phrases(resolved_text)
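
# Indicative end-to-end behaviour (the actual output depends on the clusters the
# coreference model produces; tokens are re-joined with single spaces, so
# punctuation spacing changes):
#   resolve_coreferences_across_text("Alice went home. She was tired.", ["Alice"])
#   -> something like "Alice went home . Alice was tired ."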

def remove_consecutive_duplicate_phrases(text):
    # Collapse word sequences that are immediately repeated, which pronoun
    # replacement can occasionally produce.
    words = text.split()
    i = 0
    while i < len(words) - 1:
        j = i + 1
        while j < len(words):
            if words[i:j] == words[j:j + (j - i)]:
                del words[j:j + (j - i)]
            else:
                j += 1
        i += 1
    return " ".join(words)

def process_text(text, main_characters):
    # Only run the (more expensive) coreference pass when the text actually
    # contains pronouns that could be replaced.
    pronoun_density, named_entity_count = calculate_pronoun_density(text)
    if pronoun_density > 0:
        return resolve_coreferences_across_text(text, main_characters)
    return text

@app.post("/predict")
async def predict(coref_request: CorefRequest):
    # main_characters arrives as a comma-separated string, e.g. "Alice,Bob".
    main_characters = [name.strip() for name in coref_request.main_characters.split(",")]
    resolved_text = process_text(coref_request.text, main_characters)
    if resolved_text:
        return {"resolved_text": resolved_text}
    raise HTTPException(status_code=400, detail="Coreference resolution failed")
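
# Illustrative request against the endpoint above (assumes the app is served
# locally on the default port 7860; adjust host/port as needed):
#   curl -X POST http://localhost:7860/predict \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Alice went home. She was tired.", "main_characters": "Alice,Bob"}'
# A successful response has the shape {"resolved_text": "..."}.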

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)