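# NOTE(review): the six lines below appear to be web-page residue captured along
# with the file rather than program text; kept as comments so the module parses.
# Bikas0's picture
# nltk download cancel
# beef15f verified
# raw
# history blame
# 3.05 kB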
import os
import re
from pathlib import Path

import nltk
import torch
from flask import Flask, request, render_template, jsonify
# Select the torch device: use CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
from transformers import pipeline
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
# nltk.download('punkt')
# nltk.download('wordnet')
# # Ensure NLTK data is downloaded
# nltk.download('punkt', download_dir=Path('/app/nltk_data'))
# nltk.download('wordnet', download_dir=Path('/app/nltk_data'))
# Add the directory named by the NLTK_DATA environment variable to NLTK's data
# search path. os.getenv() returns None when the variable is unset, and None is
# not a usable search-path entry, so only append a real value.
nltk_data_dir = os.getenv('NLTK_DATA')
if nltk_data_dir:
    nltk.data.path.append(nltk_data_dir)

# Flask application object that the route decorators below attach to.
app = Flask(__name__)

# Tokenizer loaded from the local "summary/tokenizer" directory; it is used both
# for length checks and by the summarization pipeline.
tokenizer = AutoTokenizer.from_pretrained(Path("summary/tokenizer"))

# Location of the seq2seq model that the /summarize endpoint loads.
model_name = "summary/pegasus-samsum-model"
def remove_spaces_before_punctuation(text):
    """Tidy text by closing gaps before punctuation and dropping square brackets.

    Any run of whitespace that immediately precedes one of ``. , ; ! ?`` is
    removed, after which every ``[`` and ``]`` character is deleted.

    Args:
        text: The string to tidy.

    Returns:
        The tidied string.
    """
    without_gaps = re.sub(r'(\s+)([.,;!?])', r'\2', text)
    return re.sub(r'\[|\]', '', without_gaps)
def replace_pronouns(text):
    """Replace "they" with "He/She" and fix the agreement of its adjacent verb.

    Every stand-alone "they" (any letter case) becomes the literal "He/She".
    A plural verb form ("are", "have", "were") is rewritten to its singular
    form ("is", "has", "was") only when it sits directly next to that "they",
    either just after it ("they are") or just before it ("were they").
    Restricting the rewrite to those positions keeps verbs that belong to
    other subjects intact, e.g. "John and Mary are friends" is left alone.
    A verb separated from "they" by other words ("they both are") is not
    changed.

    Args:
        text: The text to rewrite.

    Returns:
        The rewritten text.
    """
    singular_forms = {'are': 'is', 'have': 'has', 'were': 'was'}

    def singularize(verb):
        # Keep a leading capital so a sentence-initial verb stays capitalised.
        replacement = singular_forms[verb.lower()]
        return replacement.capitalize() if verb[0].isupper() else replacement

    def rewrite(match):
        if match.group('pre'):
            return singularize(match.group('pre')) + match.group('pre_gap') + 'He/She'
        if match.group('post'):
            return 'He/She' + match.group('post_gap') + singularize(match.group('post'))
        return 'He/She'

    # One pass over the text: "<verb> they", "they <verb>", or a bare "they".
    pattern = re.compile(
        r'\b(?:(?P<pre>are|have|were)(?P<pre_gap>\s+)they'
        r'|they(?:(?P<post_gap>\s+)(?P<post>are|have|were))?)\b',
        flags=re.IGNORECASE,
    )
    return pattern.sub(rewrite, text)
def clean_and_lemmatize(text):
    """Normalise raw input text before it is summarised.

    Digits are removed, line breaks are turned into spaces, every character
    other than word characters, whitespace, commas and hyphens is dropped,
    and the lower-cased tokens are lemmatized with ``WordNetLemmatizer``.

    Args:
        text: The raw input string.

    Returns:
        The lemmatized tokens joined by single spaces.
    """
    # Remove digits.
    text = re.sub(r'\d+', '', text)
    # Turn newlines into spaces instead of deleting them: deleting would fuse
    # the last word of one line with the first word of the next.
    text = text.replace('\n', ' ')
    # Drop symbols and punctuation, keeping commas and hyphens.
    text = re.sub(r'[^\w\s,-]', '', text)
    # Tokenize the lower-cased text and lemmatize each token.
    tokens = word_tokenize(text.lower())
    lemmatizer = WordNetLemmatizer()
    return ' '.join(lemmatizer.lemmatize(token) for token in tokens)
# Summarization pipeline, created on the first request and reused afterwards so
# the model is not reloaded for every call. The model therefore stays resident
# in memory for the lifetime of the process.
_summarizer = None


def _get_summarizer():
    """Return the shared summarization pipeline, loading the model on first use."""
    global _summarizer
    if _summarizer is None:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
        _summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=device)
    return _summarizer


@app.route('/summarize', methods=['POST'])
def summarize():
    """Summarise the ``input_text`` form field and return the result as JSON.

    The text is cleaned and lemmatized, checked against the token limit,
    summarised, and post-processed (punctuation spacing and pronoun rewrite).

    Returns:
        A JSON response of the form ``{'summary': <text>}`` on success, or
        ``{'error': <message>}`` when the cleaned input is empty or is 1024
        tokens or longer.
    """
    # Get the input text from the request and normalise it.
    input_text = clean_and_lemmatize(request.form['input_text'])

    # Nothing left to summarise (blank input, or only digits/symbols).
    if not input_text.strip():
        return jsonify({'error': 'Input text is empty.'})

    # Reject input that is too long, measured on the cleaned text.
    sequence_length = len(tokenizer.tokenize(input_text))
    if sequence_length >= 1024:
        return jsonify({'error': 'Input text exceeds maximum token length of 1023.'})

    gen_kwargs = {"length_penalty": 0.8, "num_beams": 8, "max_length": 128}
    try:
        text = _get_summarizer()(input_text, **gen_kwargs)[0]["summary_text"]
    finally:
        # Clear the GPU cache whether or not generation succeeded.
        torch.cuda.empty_cache()

    output_text = replace_pronouns(remove_spaces_before_punctuation(text))

    # Return the summary
    return jsonify({'summary': output_text})
@app.route('/')
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template('index.html')
    return page
if __name__ == '__main__':
    # The server binds to every interface, so the interactive debugger (which
    # allows code execution from the browser) must not be on by default.
    # Opt in for local development by setting FLASK_DEBUG to 1/true/yes/on.
    debug_enabled = os.getenv('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(host='0.0.0.0', debug=debug_enabled, port=7860)  # This is Host Port