aus10powell commited on
Commit
9191bf8
1 Parent(s): b6ab90a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -4
app.py CHANGED
@@ -21,12 +21,14 @@ import scripts.sentiment as sentiment
21
  import scripts.twitter_scraper as ts
22
  import scripts.utils as utils
23
  from scripts import generative
 
24
 
25
  logging.basicConfig(level=logging.INFO)
26
 
27
  app = FastAPI()
28
  templates = Jinja2Templates(directory="templates")
29
  app.mount("/static", StaticFiles(directory="static"), name="static")
 
30
  # Construct absolute path to models folder
31
  models_path = os.path.abspath("models")
32
 
@@ -168,19 +170,46 @@ async def get_sentiment(username: str) -> Dict[str, Dict[str, float]]:
168
  @app.post("/api/generate")
169
  # async def generate_text(account: str, text: str):
170
  async def generate_text(request: Request):
171
-
172
  data = await request.json()
173
  print("*"*50)
174
  print("POST Request:")
175
- print(data['account'],data['text'])
 
 
 
176
  generated_text = generative.generate_account_text(
177
  prompt=data['text'], model_dir=os.path.join(models_path, data['account'])
178
  )
179
  # return one example
180
  generated_text = generated_text[0]["generated_text"]
181
- return {"generated_text": generated_text}
182
-
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
  # if __name__ == "__main__":
186
  # # uvicorn.run(app, host="0.0.0.0", port=8000)
 
21
  import scripts.twitter_scraper as ts
22
  import scripts.utils as utils
23
  from scripts import generative
24
+ import nltk
25
 
26
  logging.basicConfig(level=logging.INFO)
27
 
28
  app = FastAPI()
29
  templates = Jinja2Templates(directory="templates")
30
  app.mount("/static", StaticFiles(directory="static"), name="static")
31
+
32
  # Construct absolute path to models folder
33
  models_path = os.path.abspath("models")
34
 
 
170
  @app.post("/api/generate")
171
  # async def generate_text(account: str, text: str):
172
  async def generate_text(request: Request):
173
+ print("*"*50)
174
  data = await request.json()
175
  print("*"*50)
176
  print("POST Request:")
177
+
178
+ # Check length of input, if it is greater than 10 tokens, the text is sent off to a summarizer to generate:
179
+
180
+
181
  generated_text = generative.generate_account_text(
182
  prompt=data['text'], model_dir=os.path.join(models_path, data['account'])
183
  )
184
  # return one example
185
  generated_text = generated_text[0]["generated_text"]
 
 
186
 
187
+ ###################################################
188
+ ## Clean up generate text
189
+ # Get rid of final sentence
190
+ sentences = nltk.sent_tokenize(generated_text)
191
+ unique_sentences = set()
192
+ non_duplicate_sentences = []
193
+ for sentence in sentences:
194
+ if sentence not in unique_sentences:
195
+ non_duplicate_sentences.append(sentence)
196
+ unique_sentences.add(sentence)
197
+ final_text = " ".join(non_duplicate_sentences[:-1])
198
+
199
+
200
+ return {"generated_text": final_text}
201
+
202
+ @app.get("/examples1")
203
+ async def read_examples():
204
+ with open("templates/charts/handle_sentiment_breakdown.html") as f:
205
+ html = f.read()
206
+ return HTMLResponse(content=html)
207
+
208
+ @app.get("/examples2")
209
+ async def read_examples():
210
+ with open("templates/charts/handle_sentiment_timesteps.html") as f:
211
+ html = f.read()
212
+ return HTMLResponse(content=html)
213
 
214
  # if __name__ == "__main__":
215
  # # uvicorn.run(app, host="0.0.0.0", port=8000)