Create app.py
app.py
ADDED
import os
from flask import Flask, request, jsonify
import spacy
import re

# Point cache/config paths at /tmp so model downloads land in writable
# directories (the default locations are read-only on Hugging Face Spaces).
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
os.environ['MPLCONFIGDIR'] = '/tmp/.matplotlib'

# Initialize Flask app
app = Flask(__name__)

# Load the spaCy models once at startup: en_core_web_sm for POS tagging and
# NER, and the experimental en_coreference_web_trf pipeline (distributed via
# spacy-experimental) for coreference clusters.
nlp = spacy.load("en_core_web_sm")
nlp_coref = spacy.load("en_coreference_web_trf")

# Third-person pronouns that may be rewritten to the character they refer to.
REPLACE_PRONOUNS = {"he", "she", "they", "He", "She", "They"}

def extract_core_name(mention_text, main_characters):
    """Map a mention span to a known main character, falling back to the
    mention's last word when no character name is contained in it."""
    words = mention_text.split()
    for character in main_characters:
        if character.lower() in mention_text.lower():
            return character
    return words[-1]
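# Examples with hypothetical names (pure string matching, no model involved):
#   extract_core_name("young Anna", ["Anna", "Ben"])  -> "Anna"
#   extract_core_name("the old man", ["Anna", "Ben"]) -> "man"  (last-word fallback)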
def calculate_pronoun_density(text):
    """Return (replaceable pronouns per PERSON entity, PERSON entity count)."""
    doc = nlp(text)
    pronoun_count = sum(1 for token in doc if token.pos_ == "PRON" and token.text in REPLACE_PRONOUNS)
    named_entity_count = sum(1 for ent in doc.ents if ent.label_ == "PERSON")
    # max(..., 1) guards against division by zero when no PERSON entity is found.
    return pronoun_count / max(named_entity_count, 1), named_entity_count
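# Illustration with a hypothetical sentence (actual counts depend on the
# en_core_web_sm tagger and NER): for "He saw Anna. He waved.", assuming
# "Anna" is tagged PERSON and both "He" tokens PRON, this returns (2.0, 1).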
def resolve_coreferences_across_text(text, main_characters):
    doc = nlp_coref(text)

    # Map token index -> replacement name for every replaceable pronoun that
    # belongs to a cluster anchored on one of the main characters.
    coref_mapping = {}
    for key, cluster in doc.spans.items():
        if re.match(r"coref_clusters_*", key):
            main_mention = cluster[0]
            core_name = extract_core_name(main_mention.text, main_characters)
            if core_name in main_characters:
                for mention in cluster:
                    for token in mention:
                        if token.text in REPLACE_PRONOUNS:
                            # Match the pronoun's casing ("He" -> "Anna", "he" -> "anna").
                            core_name_final = core_name if token.text.istitle() else core_name.lower()
                            coref_mapping[token.i] = core_name_final

    # Rebuild the text sentence by sentence, substituting at most one
    # occurrence of each character name per sentence.
    resolved_tokens = []
    current_sentence_characters = set()
    current_sentence = []
    for i, token in enumerate(doc):
        if token.is_sent_start and current_sentence:
            resolved_tokens.extend(current_sentence)
            current_sentence_characters.clear()
            current_sentence = []
        if i in coref_mapping:
            core_name = coref_mapping[i]
            if core_name not in current_sentence_characters and core_name.lower() not in [t.lower() for t in current_sentence]:
                current_sentence.append(core_name)
                current_sentence_characters.add(core_name)
            else:
                # The name already appears in this sentence; keep the pronoun.
                current_sentence.append(token.text)
        else:
            current_sentence.append(token.text)
    resolved_tokens.extend(current_sentence)

    # Tokens are re-joined with single spaces, so punctuation comes out
    # space-separated; duplicate phrases introduced by substitution are pruned.
    resolved_text = " ".join(resolved_tokens)
    return remove_consecutive_duplicate_phrases(resolved_text)
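# Sketch of the intended end-to-end effect on a hypothetical input, assuming
# en_coreference_web_trf clusters "She" with "Anna" (real output depends on
# the model):
#   resolve_coreferences_across_text("Anna met Ben. She waved.", ["Anna", "Ben"])
#   -> "Anna met Ben . Anna waved ."  (tokens re-joined with single spaces)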
def remove_consecutive_duplicate_phrases(text):
    """Collapse immediately repeated word sequences introduced by substitution."""
    words = text.split()
    i = 0
    while i < len(words) - 1:
        j = i + 1
        while j < len(words):
            # Compare the phrase words[i:j] with the equally long phrase that follows it.
            if words[i:j] == words[j:j + (j - i)]:
                del words[j:j + (j - i)]
            else:
                j += 1
        i += 1
    return " ".join(words)
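# Deterministic string helper, so these examples are exact:
#   remove_consecutive_duplicate_phrases("Anna Anna waved")     -> "Anna waved"
#   remove_consecutive_duplicate_phrases("to the to the store") -> "to the store"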
def process_text(text, main_characters):
    # Run coreference resolution whenever any replaceable pronouns are
    # present; otherwise return the text unchanged.
    pronoun_density, named_entity_count = calculate_pronoun_density(text)
    if pronoun_density > 0:
        return resolve_coreferences_across_text(text, main_characters)
    else:
        return text
# API endpoint to handle coreference resolution
@app.route('/predict', methods=['POST'])
def predict():
    data = request.get_json(silent=True) or {}
    text = data.get('text')
    main_characters = data.get('main_characters')
    if not text or not main_characters:
        return jsonify({"error": "'text' and 'main_characters' are required"}), 400
    # main_characters arrives as a comma-separated string, e.g. "Anna,Ben".
    resolved_text = process_text(text, main_characters.split(","))
    return jsonify({"resolved_text": resolved_text})

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
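A quick smoke test once the server is running (locally or on the Space); the port and field names come from the code above, and the exact resolved text depends on how en_coreference_web_trf clusters the mentions:

import requests  # assumes the requests package is available

resp = requests.post(
    "http://localhost:7860/predict",
    json={"text": "Anna met Ben. She waved.", "main_characters": "Anna,Ben"},
)
print(resp.json())  # expected shape: {"resolved_text": "Anna met Ben . Anna waved ."}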