import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import spacy
import re
import uvicorn

# Set environment variables for writable cache directories
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
os.environ['MPLCONFIGDIR'] = '/tmp/.matplotlib'

# Initialize FastAPI app
app = FastAPI()

# Load the spaCy models once
nlp = spacy.load("en_core_web_sm")
nlp_coref = spacy.load("en_coreference_web_trf")

REPLACE_PRONOUNS = {"he", "she", "they", "He", "She", "They"}


class CorefRequest(BaseModel):
    text: str
    main_characters: str
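# Example request payload for this model (values are illustrative only):
#   {"text": "Alice met Bob. She smiled.", "main_characters": "Alice,Bob"}
# main_characters is a comma-separated string; it is split into a list in predict().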
def extract_core_name(mention_text, main_characters):
    words = mention_text.split()
    for character in main_characters:
        if character.lower() in mention_text.lower():
            return character
    return words[-1]
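# Illustrative (hypothetical) behaviour of extract_core_name:
#   extract_core_name("Mr. Darcy", ["Darcy", "Elizabeth"]) -> "Darcy"  (matched character)
#   extract_core_name("the old man", ["Darcy", "Elizabeth"]) -> "man"  (falls back to last word)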
def calculate_pronoun_density(text):
    doc = nlp(text)
    pronoun_count = sum(1 for token in doc if token.pos_ == "PRON" and token.text in REPLACE_PRONOUNS)
    named_entity_count = sum(1 for ent in doc.ents if ent.label_ == "PERSON")
    return pronoun_count / max(named_entity_count, 1), named_entity_count
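# Rough example of the ratio this returns (assuming the tagger labels the target
# pronouns as PRON and finds no PERSON entities): for "He said she would come",
# pronoun_count is 2 and named_entity_count is 0, so the result is (2.0, 0).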
def resolve_coreferences_across_text(text, main_characters):
    doc = nlp_coref(text)
    coref_mapping = {}

    for key, cluster in doc.spans.items():
        if re.match(r"coref_clusters_*", key):
            main_mention = cluster[0]
            core_name = extract_core_name(main_mention.text, main_characters)
            if core_name in main_characters:
                for mention in cluster:
                    for token in mention:
                        if token.text in REPLACE_PRONOUNS:
                            core_name_final = core_name if token.text.istitle() else core_name.lower()
                            coref_mapping[token.i] = core_name_final

    resolved_tokens = []
    current_sentence_characters = set()
    current_sentence = []

    for i, token in enumerate(doc):
        if token.is_sent_start and current_sentence:
            resolved_tokens.extend(current_sentence)
            current_sentence_characters.clear()
            current_sentence = []

        if i in coref_mapping:
            core_name = coref_mapping[i]
            if core_name not in current_sentence_characters and core_name.lower() not in [t.lower() for t in current_sentence]:
                current_sentence.append(core_name)
                current_sentence_characters.add(core_name)
            else:
                current_sentence.append(token.text)
        else:
            current_sentence.append(token.text)

    resolved_tokens.extend(current_sentence)
    resolved_text = " ".join(resolved_tokens)
    return remove_consecutive_duplicate_phrases(resolved_text)
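# Sketch of the intended effect (assuming the coref model clusters "She" with "Alice"
# and "Alice" appears in main_characters):
#   "Alice went home. She was tired."  ->  "Alice went home. Alice was tired."
# Pronouns are only rewritten when the cluster resolves to one of main_characters,
# and a name already present in the current sentence is not inserted twice.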
def remove_consecutive_duplicate_phrases(text):
    words = text.split()
    i = 0
    while i < len(words) - 1:
        j = i + 1
        while j < len(words):
            if words[i:j] == words[j:j + (j - i)]:
                del words[j:j + (j - i)]
            else:
                j += 1
        i += 1
    return " ".join(words)
def process_text(text, main_characters):
    pronoun_density, named_entity_count = calculate_pronoun_density(text)
    min_named_entities = len(main_characters)

    if pronoun_density > 0:
        return resolve_coreferences_across_text(text, main_characters)
    else:
        return text
# Register the handler with the app (the "/predict" path is assumed here).
@app.post("/predict")
async def predict(coref_request: CorefRequest):
    resolved_text = process_text(coref_request.text, coref_request.main_characters.split(","))
    if resolved_text:
        return {"resolved_text": resolved_text}
    raise HTTPException(status_code=400, detail="Coreference resolution failed")
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
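# Example call against a locally running instance (assumes the "/predict" path above
# and that both spaCy models are installed):
#
#   curl -X POST http://localhost:7860/predict \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Alice went home. She was tired.", "main_characters": "Alice,Bob"}'
#
# Expected shape of the response: {"resolved_text": "..."}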