"""Flask service that summarizes text with a locally stored Pegasus model.

Routes:
    GET  /           -- renders ``index.html``.
    POST /summarize  -- expects a form field ``input_text`` and returns JSON
                        with either a ``summary`` or an ``error`` key.
"""

import os
import re
from functools import lru_cache
from pathlib import Path

import nltk
import torch
from flask import Flask, request, render_template, jsonify
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from transformers import pipeline
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The tokenizer and lemmatizer below rely on NLTK data ('punkt', 'wordnet').
# NOTE(review): nothing in this file downloads that data; presumably it is
# provisioned ahead of time in the directory named by NLTK_DATA -- confirm.
_nltk_data_dir = os.getenv('NLTK_DATA')
if _nltk_data_dir:
    # Only register the directory when the variable is actually set, so that
    # None never ends up in NLTK's search path.
    nltk.data.path.append(_nltk_data_dir)

app = Flask(__name__)

tokenizer = AutoTokenizer.from_pretrained(Path("summary/tokenizer"))
model_name = "summary/pegasus-samsum-model"

# Inputs with this many tokens or more (after cleaning) are rejected.
_MAX_INPUT_TOKENS = 1024

# Generation settings passed to the summarization pipeline on every call.
_GEN_KWARGS = {"length_penalty": 0.8, "num_beams": 8, "max_length": 128}

# Verb substitutions applied by replace_pronouns().
_VERB_AGREEMENT = {'are': 'is', 'have': 'has', 'were': 'was'}


def remove_spaces_before_punctuation(text: str) -> str:
    """Tidy up punctuation spacing in generated text.

    Removes any whitespace that directly precedes one of ``. , ; ! ?`` and
    strips all square brackets.

    Args:
        text: The text to clean.

    Returns:
        The cleaned text.
    """
    result = re.sub(r'(\s+)([.,;!?])', r'\2', text)
    return re.sub(r'\[|\]', '', result)


def replace_pronouns(text: str) -> str:
    """Rewrite plural "they" phrasing into a singular "He/She" form.

    Every standalone "they" (any letter case) becomes ``He/She``. Every
    standalone lowercase "are", "have" or "were" becomes "is", "has" or
    "was". The verb substitution is applied across the whole text, not only
    to verbs that follow the replaced pronoun.

    Args:
        text: The text to rewrite.

    Returns:
        The rewritten text.
    """
    text = re.sub(r'\bthey\b', 'He/She', text, flags=re.IGNORECASE)
    return re.sub(
        r'\b(are|have|were)\b',
        lambda match: _VERB_AGREEMENT[match.group()],
        text,
    )


def clean_and_lemmatize(text: str) -> str:
    """Normalize raw input text before it is handed to the model.

    Steps, in order: drop digit runs, turn newlines into spaces, remove every
    character that is not a word character, whitespace, comma or hyphen,
    lowercase, tokenize with NLTK, lemmatize each token, and re-join the
    tokens with single spaces.

    Args:
        text: The raw input text.

    Returns:
        The cleaned, lowercased, lemmatized text.
    """
    text = re.sub(r'\d+', '', text)
    # Newlines become spaces (not empty strings) so the last word of one line
    # is not glued to the first word of the next.
    text = re.sub(r'[^\w\s,-]', '', text.replace('\n', ' '))

    tokens = word_tokenize(text.lower())
    lemmatizer = WordNetLemmatizer()
    return ' '.join(lemmatizer.lemmatize(token) for token in tokens)


@lru_cache(maxsize=1)
def _get_summarizer():
    """Build the summarization pipeline once and reuse it afterwards.

    The model is loaded on first use rather than at import time, and the
    resulting pipeline is cached so later requests skip the reload.
    """
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    return pipeline(
        "summarization", model=model, tokenizer=tokenizer, device=device
    )


@app.route('/summarize', methods=['POST'])
def summarize():
    """Summarize the text posted in the ``input_text`` form field.

    Returns:
        A JSON response: ``{'summary': ...}`` on success, or
        ``{'error': ...}`` when the cleaned input is empty or too long.
    """
    input_text = clean_and_lemmatize(request.form['input_text'])

    # Nothing left to summarize once the input has been cleaned.
    if not input_text.strip():
        return jsonify({'error': 'Input text is empty.'})

    sequence_length = len(tokenizer.tokenize(input_text))
    if sequence_length >= _MAX_INPUT_TOKENS:
        return jsonify({'error': 'Input text exceeds maximum token length of 1023.'})

    summarizer = _get_summarizer()
    text = summarizer(input_text, **_GEN_KWARGS)[0]["summary_text"]
    output_text = replace_pronouns(remove_spaces_before_punctuation(text))

    # Clear the GPU cache
    torch.cuda.empty_cache()

    return jsonify({'summary': output_text})


@app.route('/')
def index():
    """Render the landing page template."""
    return render_template('index.html')


if __name__ == '__main__':
    # The server listens on every interface, so the interactive debugger must
    # not be enabled by default; opt in explicitly with FLASK_DEBUG=1.
    debug_enabled = os.getenv('FLASK_DEBUG', '').lower() in ('1', 'true')
    app.run(host='0.0.0.0', debug=debug_enabled, port=7860)  # This is Host Port