RanM committed on
Commit 3cef660 · verified · 1 Parent(s): 61b177b

Create app.py

Files changed (1)
  app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
+ import os
+
+ # Point cache/config paths at writable directories before importing libraries
+ # that may read these variables at import time (e.g. on a read-only Space).
+ os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
+ os.environ['MPLCONFIGDIR'] = '/tmp/.matplotlib'
+
+ import re
+
+ import spacy
+ from flask import Flask, request, jsonify
+
+ # Initialize the Flask app
+ app = Flask(__name__)
+
+ # Load the spaCy pipelines once at startup: a small English model for
+ # POS tagging/NER, and the experimental transformer coreference model.
+ nlp = spacy.load("en_core_web_sm")
+ nlp_coref = spacy.load("en_coreference_web_trf")
+
+ # Pronouns that are candidates for replacement by a character name
+ REPLACE_PRONOUNS = {"he", "she", "they", "He", "She", "They"}
+
+ def extract_core_name(mention_text, main_characters):
+     """Map a mention to a known main character, else fall back to its last word."""
+     words = mention_text.split()
+     for character in main_characters:
+         if character.lower() in mention_text.lower():
+             return character
+     return words[-1]
+
+ def calculate_pronoun_density(text):
+     """Return the ratio of replaceable pronouns to PERSON entities, plus the entity count."""
+     doc = nlp(text)
+     pronoun_count = sum(1 for token in doc if token.pos_ == "PRON" and token.text in REPLACE_PRONOUNS)
+     named_entity_count = sum(1 for ent in doc.ents if ent.label_ == "PERSON")
+     return pronoun_count / max(named_entity_count, 1), named_entity_count
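+
+ # Example (hypothetical input): for "Emma lost her keys. She sighed.",
+ # "She" is the only pronoun in REPLACE_PRONOUNS and "Emma" the only PERSON
+ # entity, so the function returns a density of 1 / 1 = 1.0.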
+
+ def resolve_coreferences_across_text(text, main_characters):
+     """Replace pronouns with the main character they corefer with."""
+     doc = nlp_coref(text)
+
+     # Build a token-index -> replacement-name mapping from the coref clusters.
+     coref_mapping = {}
+     for key, cluster in doc.spans.items():
+         if re.match(r"coref_clusters_*", key):
+             main_mention = cluster[0]
+             core_name = extract_core_name(main_mention.text, main_characters)
+             if core_name in main_characters:
+                 for mention in cluster:
+                     for token in mention:
+                         if token.text in REPLACE_PRONOUNS:
+                             # Preserve the pronoun's capitalization.
+                             core_name_final = core_name if token.text.istitle() else core_name.lower()
+                             coref_mapping[token.i] = core_name_final
+
+     # Rebuild the text sentence by sentence, replacing mapped pronouns while
+     # avoiding the same character name appearing twice in one sentence.
+     resolved_tokens = []
+     current_sentence_characters = set()
+     current_sentence = []
+     for i, token in enumerate(doc):
+         if token.is_sent_start and current_sentence:
+             resolved_tokens.extend(current_sentence)
+             current_sentence_characters.clear()
+             current_sentence = []
+         if i in coref_mapping:
+             core_name = coref_mapping[i]
+             if core_name not in current_sentence_characters and core_name.lower() not in [t.lower() for t in current_sentence]:
+                 current_sentence.append(core_name)
+                 current_sentence_characters.add(core_name)
+             else:
+                 current_sentence.append(token.text)
+         else:
+             current_sentence.append(token.text)
+     resolved_tokens.extend(current_sentence)
+     resolved_text = " ".join(resolved_tokens)
+     return remove_consecutive_duplicate_phrases(resolved_text)
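+
+ # Example (hypothetical): with main_characters = ["Emma"], a text like
+ # "Emma grabbed her coat. She left." would come back as
+ # "Emma grabbed her coat . Emma left ." (tokens re-joined with spaces),
+ # assuming the coref model links "She" to "Emma".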
+
+ def remove_consecutive_duplicate_phrases(text):
+     """Collapse any word sequence that is immediately repeated word-for-word."""
+     words = text.split()
+     i = 0
+     while i < len(words) - 1:
+         j = i + 1
+         while j < len(words):
+             # If the phrase words[i:j] is immediately repeated, drop the copy.
+             if words[i:j] == words[j:j + (j - i)]:
+                 del words[j:j + (j - i)]
+             else:
+                 j += 1
+         i += 1
+     return " ".join(words)
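+
+ # Example: "Emma Emma smiled" -> "Emma smiled", and a repeated multi-word
+ # phrase such as "the dog the dog barked" collapses to "the dog barked".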
+
+ def process_text(text, main_characters):
+     """Run coreference resolution only when the text contains replaceable pronouns."""
+     pronoun_density, _named_entity_count = calculate_pronoun_density(text)
+     if pronoun_density > 0:
+         return resolve_coreferences_across_text(text, main_characters)
+     return text
+
+ # API endpoint for coreference resolution
+ @app.route('/predict', methods=['POST'])
+ def predict():
+     data = request.get_json(silent=True) or {}
+     text = data.get('text')
+     main_characters = data.get('main_characters')
+     if not text or not main_characters:
+         return jsonify({"error": "Both 'text' and 'main_characters' are required."}), 400
+     # 'main_characters' is expected as a comma-separated string, e.g. "Emma,Liam".
+     resolved_text = process_text(text, main_characters.split(","))
+     return jsonify({"resolved_text": resolved_text})
+
+ if __name__ == "__main__":
+     app.run(host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
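
A minimal sketch of a client call, assuming the app is running locally on the default port 7860 and the `requests` package is installed; the example text and character names are made up:

import requests

# POST the story text plus a comma-separated list of main characters,
# matching the fields read by predict() above.
resp = requests.post(
    "http://localhost:7860/predict",
    json={
        "text": "Emma met Liam at the station. She waved and they left together.",
        "main_characters": "Emma,Liam",
    },
)
print(resp.json()["resolved_text"])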