MedChattRe / app.py
GuysRGithub
Update
a4cc6d9
raw
history blame
3.88 kB
from flask import Flask, request
import requests
import os
import re
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
import subprocess
# Hugging Face checkpoint ids: the tokenizer comes from the base BART
# checkpoint, the weights from the fine-tuned one.
_TOKENIZER_CHECKPOINT = "facebook/bart-base"
_MODEL_CHECKPOINT = "GuysTrans/bart-base-finetuned-xsum"

# Both objects are loaded once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_CHECKPOINT)
# Greeting phrases that post_process swaps for their mapped value; every
# phrase currently maps to the empty string, i.e. it is deleted.
# Order matters: the longest greeting is listed first.
map_words = {
    phrase: ""
    for phrase in (
        "Hello and Welcome to 'Ask A Doctor' service",
        "Hello,",
        "Hi,",
        "Hello",
        "Hi",
    )
}
# Lower-case markers of greeting/sign-off lines: post_process drops every
# output line that contains one of them.
word_remove_sentence = ["hello", "hi", "regards", "dr.", "physician"]
def generate_summary(question, model):
    """Run *model* on *question* and decode what it generates.

    The question is tokenized with the module-level ``tokenizer``, padded
    and truncated to 512 tokens, moved to the model's device and passed to
    ``model.generate`` with a budget of 512 new tokens.

    Returns:
        A pair ``(generated, decoded)``: the raw result of
        ``model.generate`` and the result of ``tokenizer.batch_decode`` on
        it with special tokens skipped.
    """
    encoded = tokenizer(
        question,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )

    generated = model.generate(
        encoded.input_ids.to(model.device),
        attention_mask=encoded.attention_mask.to(model.device),
        max_new_tokens=512,
    )
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return generated, decoded
app = Flask(__name__)

# Messenger Send API endpoint that send_message posts to.
# NOTE(review): v2.6 is an old Graph API version; confirm it is still served.
FB_API_URL = 'https://graph.facebook.com/v2.6/me/messages'

# Shared secret compared against `hub.verify_token` when the webhook is
# verified. Supply it through the VERIFY_TOKEN environment variable; the
# literal is kept only as a fallback so existing deployments keep working.
# NOTE(review): the fallback is a credential committed to source control.
# Rotate it and drop the default once the environment variable is set.
VERIFY_TOKEN = os.environ.get(
    'VERIFY_TOKEN', '5rApTs/BRm6jtiwApOpIdjBHe73ifm6mNGZOsYkwwAw=')

# Page access token sent with every Send API call. It must be present in the
# environment: a missing variable raises KeyError while the module is imported.
PAGE_ACCESS_TOKEN = os.environ['PAGE_ACCESS_TOKEN']
def get_bot_response(message):
    """Return the bot's cleaned-up answer to *message*.

    The message is run through ``generate_summary`` with the module-level
    model; the first decoded string is then passed through ``post_process``.
    """
    _, decoded = generate_summary(message, model)
    first_answer = decoded[0]
    return post_process(first_answer)
def verify_webhook(req):
    """Answer a webhook verification request.

    Args:
        req: The incoming request; its query arguments are read through
            ``req.args.get``.

    Returns:
        The value of the ``hub.challenge`` query argument (``None`` when it
        is absent) if ``hub.verify_token`` equals ``VERIFY_TOKEN``;
        otherwise the string ``"incorrect"``.
    """
    # Imported locally so this function does not depend on the file header.
    import hmac

    supplied = req.args.get("hub.verify_token")
    # The token is a secret and the caller is untrusted, so compare in
    # constant time. Encoding to bytes lets compare_digest accept non-ASCII
    # input instead of raising TypeError.
    if supplied is not None and hmac.compare_digest(
            supplied.encode("utf-8"), VERIFY_TOKEN.encode("utf-8")):
        return req.args.get("hub.challenge")
    return "incorrect"
def respond(sender, message):
    """Answer *message* and deliver the answer to *sender*.

    The reply comes from ``get_bot_response`` and is handed to
    ``send_message``; the reply text is also returned to the caller.
    """
    reply = get_bot_response(message)
    send_message(sender, reply)
    return reply
def is_user_message(message):
    """Tell whether a messaging event is a text message written by the user.

    The result is truthy only when the event has a ``message`` entry with
    non-empty ``text`` that is not flagged ``is_echo``. When a check fails,
    the falsy value that failed it is returned unchanged, so the result is
    not always a ``bool``.
    """
    body = message.get('message')
    if not body:
        return body

    text = body.get('text')
    if not text:
        return text

    return not body.get("is_echo")
@app.route("/webhook", methods=['GET', 'POST'])
def listen():
    """Handle requests to the `/webhook` endpoint.

    GET requests are treated as webhook verification and delegated to
    ``verify_webhook``. For any other request the JSON body is scanned and
    every user text message found in it is answered through ``respond``.

    Returns:
        The verification result for GET, otherwise the string ``"ok"``.
    """
    if request.method == 'GET':
        return verify_webhook(request)

    # A body without the expected keys is acknowledged and ignored rather
    # than turned into a 500 error.
    payload = request.json or {}

    # Walk every entry, not just the first one, so that events delivered
    # together in one request are all answered.
    for entry in payload.get('entry', []):
        for event in entry.get('messaging', []):
            if is_user_message(event):
                respond(event['sender']['id'], event['message']['text'])
    return "ok"
def send_message(recipient_id, text, timeout=10):
    """Send a text message to a Messenger user through ``FB_API_URL``.

    Args:
        recipient_id: Id of the user the message is addressed to.
        text: Message text.
        timeout: Seconds to wait for the HTTP call before giving up. Without
            a timeout, a stalled connection would block the worker
            indefinitely.

    Returns:
        The decoded JSON body of the HTTP response.

    Raises:
        requests.exceptions.Timeout: If the call does not complete within
            *timeout* seconds.
    """
    payload = {
        'message': {
            'text': text
        },
        'recipient': {
            'id': recipient_id
        },
        'notification_type': 'regular'
    }
    auth = {
        'access_token': PAGE_ACCESS_TOKEN
    }
    response = requests.post(
        FB_API_URL,
        params=auth,
        json=payload,
        timeout=timeout,
    )
    return response.json()
@app.route("/webhook/chat", methods=['POST'])
def chat():
    """Answer a single chat message posted as JSON.

    Expects a JSON object with a non-empty ``message`` entry.

    Returns:
        ``{"message": <bot answer>}`` on success, or an ``{"error": ...}``
        object with HTTP status 400 when the body is not a JSON object or
        has no usable ``message``.
    """
    payload = request.json
    # Reject malformed bodies explicitly so they do not surface as a 500
    # from a KeyError or TypeError.
    if not isinstance(payload, dict) or not payload.get('message'):
        return {"error": "JSON body with a non-empty 'message' is required"}, 400

    return {"message": get_bot_response(payload['message'])}
def _phrase_regex(phrases):
    """Compile a case-insensitive regex matching any of *phrases* literally.

    A word boundary is required on each side of a phrase that starts or ends
    with a word character, so "hi" does not match inside "this". Returns
    ``None`` when there is nothing to match.
    """
    alternatives = []
    for phrase in phrases:
        if not phrase:
            continue
        pattern = re.escape(phrase)
        # \b is only meaningful next to a word character: "dr." must not
        # demand a boundary after its trailing dot.
        if re.match(r"\w", phrase[0]):
            pattern = r"\b" + pattern
        if re.match(r"\w", phrase[-1]):
            pattern += r"\b"
        alternatives.append(pattern)
    if not alternatives:
        return None
    return re.compile("|".join(alternatives), re.IGNORECASE)


def post_process(output, replacements=None, banned_words=None):
    """Strip greeting and sign-off boilerplate from generated text.

    Args:
        output: Text to clean; lines are separated by a newline character.
        replacements: Mapping of phrase to substitute text. Defaults to the
            module-level ``map_words``.
        banned_words: Words or phrases that cause a whole line to be
            dropped. Defaults to the module-level ``word_remove_sentence``.

    Returns:
        The text with every line containing a banned word removed and every
        replacement phrase substituted. Matching is case-insensitive, treats
        phrases literally, and respects word boundaries.
    """
    if replacements is None:
        replacements = map_words
    if banned_words is None:
        banned_words = word_remove_sentence

    # Build a new list instead of removing from the list being iterated,
    # which would skip the line following each removed one.
    banned = _phrase_regex(banned_words)
    lines = output.split("\n")
    if banned is not None:
        lines = [line for line in lines if not banned.search(line)]
    cleaned = "\n".join(lines)

    for phrase, substitute in replacements.items():
        pattern = _phrase_regex([phrase])
        if pattern is None:
            continue
        # A callable replacement inserts the substitute text verbatim,
        # without backslash processing.
        cleaned = pattern.sub(lambda _match, text=substitute: text, cleaned)
    return cleaned
# Reverse SSH tunnel started as a background process when the module is
# imported: serveo.net is asked to forward port 80 of the requested name
# "guysmedchatt" to localhost:7860, authenticating with the key file `id_rsa`
# in the working directory.
# NOTE(review): StrictHostKeyChecking=no skips host key verification, and the
# process handle is discarded. Confirm both are intended.
_TUNNEL_COMMAND = [
    "autossh",
    "-M", "0",
    "-o", "StrictHostKeyChecking=no",
    "-i", "id_rsa",
    "-R", "guysmedchatt:80:localhost:7860",
    "serveo.net",
]
subprocess.Popen(_TUNNEL_COMMAND)

# Earlier plain-ssh variant (blocking call, forwarding to port 5000), kept for reference:
# subprocess.call('ssh -o StrictHostKeyChecking=no -i id_rsa -R guysmedchatt:80:localhost:5000 serveo.net', shell=True)