Spaces:
Sleeping
Sleeping
import re
from pathlib import Path

import nltk
import torch
from flask import Flask, request, render_template, jsonify
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from transformers import pipeline
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Run the summarization model on the GPU when one is available, else the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Sentence-tokenizer data used by word_tokenize. Older NLTK releases read the
# pickled "punkt" model; recent releases (3.8.2+) look up "punkt_tab" instead
# and raise LookupError without it, so both are fetched.
# NOTE(review): on older NLTK the "punkt_tab" download is expected to just
# report "not found" and return False rather than raise — confirm against the
# pinned NLTK version.
nltk.download('punkt')
nltk.download('punkt_tab')
# WordNet corpus used by WordNetLemmatizer.
nltk.download('wordnet')

app = Flask(__name__)
# Tokenizer loaded once at import time from a path relative to the working
# directory; it is shared by the length check and the summarization pipeline.
tokenizer = AutoTokenizer.from_pretrained(Path("summary/tokenizer"))
# Checkpoint location passed to AutoModelForSeq2SeqLM.from_pretrained.
model_name = "summary/pegasus-samsum-model"
def remove_spaces_before_punctuation(text):
    """Tidy generated text by attaching punctuation and dropping brackets.

    Any whitespace run directly in front of ``.``, ``,``, ``;``, ``!`` or
    ``?`` is deleted so the mark sits against the preceding word. Afterwards,
    every literal ``[`` and ``]`` character is removed. The two steps run in
    that order; swapping them would change results such as ``"a [ ."``.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string.
    """
    attached = re.sub(r'\s+([.,;!?])', r'\1', text)
    without_brackets = attached.translate({ord('['): None, ord(']'): None})
    return without_brackets
def replace_pronouns(text):
    """Rewrite plural "they" as the singular placeholder "He/She".

    Every standalone "they" (matched case-insensitively) becomes "He/She".
    When that "they" is immediately followed by "are", "have" or "were", the
    verb is switched to its singular form ("is", "has", "was") so the clause
    stays grammatical. Verbs with any other subject are left untouched.
    Previously they were rewritten everywhere, e.g. "The dogs are happy" →
    "The dogs is happy".

    Args:
        text: The summary text to rewrite.

    Returns:
        The rewritten text.
    """
    singular = {'are': 'is', 'have': 'has', 'were': 'was'}

    def _swap(match):
        gap, verb = match.group(1), match.group(2)
        if verb is None:
            return 'He/She'
        # Keep the original whitespace between pronoun and verb.
        return 'He/She' + gap + singular[verb.lower()]

    return re.sub(
        r'\bthey\b(?:(\s+)(are|have|were)\b)?',
        _swap,
        text,
        flags=re.IGNORECASE,
    )
def clean_and_lemmatize(text):
    """Normalize raw input text before it is summarized.

    Steps, in order:

    1. Remove every run of digits.
    2. Turn newlines into spaces. Previously they were deleted outright,
       which glued the last word of one line to the first word of the next.
    3. Strip every character that is not a word character, whitespace, a
       comma or a hyphen.
    4. Lowercase the text and split it into words with NLTK's word_tokenize.
    5. Lemmatize each token with WordNet and join the tokens with single
       spaces.

    Args:
        text: The raw user-supplied text.

    Returns:
        The cleaned, lemmatized text.
    """
    text = re.sub(r'\d+', '', text)
    text = re.sub(r'[^\w\s,-]', '', text.replace('\n', ' '))
    tokens = word_tokenize(text.lower())
    lemmatizer = WordNetLemmatizer()
    return ' '.join(lemmatizer.lemmatize(token) for token in tokens)
# Summarization pipeline shared by all requests. It is built on first use so
# the checkpoint is read from disk once per process instead of once per
# request.
_summarizer = None


def _get_summarizer():
    """Return the shared summarization pipeline, building it on first call."""
    global _summarizer
    if _summarizer is None:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
        _summarizer = pipeline(
            "summarization", model=model, tokenizer=tokenizer, device=device
        )
    return _summarizer


def summarize():
    """Summarize the ``input_text`` form field and return the result as JSON.

    The input is cleaned with clean_and_lemmatize, and its token count is
    checked. The text is then summarized, and the output is post-processed
    with remove_spaces_before_punctuation and replace_pronouns.

    Returns:
        A JSON response: ``{'summary': ...}`` on success, or
        ``{'error': ...}`` when the cleaned input is empty or too long.
    """
    # NOTE(review): no @app.route decorator is visible for this view; confirm
    # how it is registered with `app`.
    input_text = clean_and_lemmatize(request.form['input_text'])
    tokens = tokenizer.tokenize(input_text)
    if not tokens:
        return jsonify({'error': 'Input text is empty after cleaning.'})
    # tokenize() does not count special tokens. The 1023 cap presumably leaves
    # room for an end-of-sequence token within a 1024-position model limit;
    # confirm against the model config.
    if len(tokens) >= 1024:
        return jsonify({'error': 'Input text exceeds maximum token length of 1023.'})

    gen_kwargs = {"length_penalty": 0.8, "num_beams": 8, "max_length": 128}
    text = _get_summarizer()(input_text, **gen_kwargs)[0]["summary_text"]
    output_text = replace_pronouns(remove_spaces_before_punctuation(text))

    # Release cached GPU memory held over from generation.
    torch.cuda.empty_cache()
    return jsonify({'summary': output_text})
def index():
    """Serve the landing page by rendering the ``index.html`` template."""
    # NOTE(review): no @app.route decorator is visible for this view; confirm
    # how it is registered with `app`.
    template_name = 'index.html'
    page = render_template(template_name)
    return page
if __name__ == '__main__':
    # Serve on all interfaces, port 7860.
    # Debug mode is deliberately not forced on here. The Werkzeug interactive
    # debugger allows arbitrary code execution, so it must never be reachable
    # on 0.0.0.0 by default. For local development, set FLASK_DEBUG=1; with
    # no explicit `debug=` argument, app.run() reads that variable itself.
    app.run(host='0.0.0.0', port=7860)