RanM commited on
Commit
36069fa
·
verified ·
1 Parent(s): f4a7326

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -12
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
- from flask import Flask, request, jsonify
 
3
  import spacy
4
  import re
5
 
@@ -7,8 +8,8 @@ import re
7
  os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
8
  os.environ['MPLCONFIGDIR'] = '/tmp/.matplotlib'
9
 
10
- # Initialize Flask app
11
- app = Flask(__name__)
12
 
13
  # Load the spaCy models once
14
  nlp = spacy.load("en_core_web_sm")
@@ -16,6 +17,10 @@ nlp_coref = spacy.load("en_coreference_web_trf")
16
 
17
  REPLACE_PRONOUNS = {"he", "she", "they", "He", "She", "They"}
18
 
 
 
 
 
19
  def extract_core_name(mention_text, main_characters):
20
  words = mention_text.split()
21
  for character in main_characters:
@@ -84,14 +89,13 @@ def process_text(text, main_characters):
84
  else:
85
  return text
86
 
87
- # API endpoint to handle coreference resolution
88
- @app.route('/predict', methods=['POST'])
89
- def predict():
90
- data = request.json
91
- text = data.get('text')
92
- main_characters = data.get('main_characters')
93
- resolved_text = process_text(text, main_characters.split(","))
94
- return jsonify({"resolved_text": resolved_text})
95
 
96
  if __name__ == "__main__":
97
- app.run(host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
 
 
1
  import os
2
+ from fastapi import FastAPI, HTTPException
3
+ from pydantic import BaseModel
4
  import spacy
5
  import re
6
 
 
8
  os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
9
  os.environ['MPLCONFIGDIR'] = '/tmp/.matplotlib'
10
 
11
+ # Initialize FastAPI app
12
+ app = FastAPI()
13
 
14
  # Load the spaCy models once
15
  nlp = spacy.load("en_core_web_sm")
 
17
 
18
  REPLACE_PRONOUNS = {"he", "she", "they", "He", "She", "They"}
19
 
20
+ class CorefRequest(BaseModel):
21
+ text: str
22
+ main_characters: str
23
+
24
  def extract_core_name(mention_text, main_characters):
25
  words = mention_text.split()
26
  for character in main_characters:
 
89
  else:
90
  return text
91
 
92
+ @app.post("/predict")
93
+ async def predict(coref_request: CorefRequest):
94
+ resolved_text = process_text(coref_request.text, coref_request.main_characters.split(","))
95
+ if resolved_text:
96
+ return {"resolved_text": resolved_text}
97
+ raise HTTPException(status_code=400, detail="Coreference resolution failed")
 
 
98
 
99
  if __name__ == "__main__":
100
+ import uvicorn
101
+ uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))