import os
from flask import Flask, request, render_template, jsonify
import re
import nltk
import torch
from pathlib import Path
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
# Ensure NLTK uses the correct data path (only if NLTK_DATA is set)
nltk_data_dir = os.getenv('NLTK_DATA')
if nltk_data_dir:
    nltk.data.path.append(nltk_data_dir)
# Define the device if using GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
app = Flask(__name__)
# Ensure the Transformers cache directory is set correctly (only if provided)
transformers_cache = os.getenv('TRANSFORMERS_CACHE')
if transformers_cache:
    os.environ['TRANSFORMERS_CACHE'] = transformers_cache
# Tokenizer and fine-tuned Pegasus (SAMSum) checkpoint stored locally in the Space
tokenizer = AutoTokenizer.from_pretrained(Path("summary/tokenizer"))
model_name = "summary/pegasus-samsum-model"

def remove_spaces_before_punctuation(text):
    # Drop whitespace that precedes punctuation and strip square brackets
    pattern = re.compile(r'(\s+)([.,;!?])')
    result = pattern.sub(r'\2', text)
    result = re.sub(r'\[|\]', '', result)
    return result
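
# Illustrative example (not executed): remove_spaces_before_punctuation("Fine , thanks [sic] !")
# would return "Fine, thanks sic!" -- the space before punctuation and the square brackets are dropped.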

def replace_pronouns(text):
    # Replace "they" with "He/She" and switch the most common plural verb forms to singular
    text = re.sub(r'\bthey\b', 'He/She', text, flags=re.IGNORECASE)
    text = re.sub(r'\b(are|have|were)\b', lambda x: {'are': 'is', 'have': 'has', 'were': 'was'}[x.group()], text)
    return text
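
# Illustrative example (not executed): replace_pronouns("they have arrived and they are happy")
# would return "He/She has arrived and He/She is happy".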

def clean_and_lemmatize(text):
    # Remove digits, symbols, punctuation marks, and newline characters
    text = re.sub(r'\d+', '', text)
    text = re.sub(r'[^\w\s,-]', '', text.replace('\n', ''))
    # Tokenize the lowercased text
    tokens = word_tokenize(text.lower())
    # Initialize lemmatizer
    lemmatizer = WordNetLemmatizer()
    # Lemmatize each token and join back into a sentence
    lemmatized_text = ' '.join([lemmatizer.lemmatize(token) for token in tokens])
    return lemmatized_text
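
# Illustrative example (not executed, assuming the NLTK punkt and wordnet data are available):
# clean_and_lemmatize("The 2 cats were running!") would return "the cat were running"
# (digits and the exclamation mark are stripped, and "cats" is lemmatized to "cat").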

@app.route('/summarize', methods=['POST'])
def summarize():
    # Get the input text from the request
    input_text = request.form['input_text']
    # Token count of the raw input (kept for reference; not used below)
    tokens_org_text = tokenizer.tokenize(input_text)
    sequence_length_org_text = len(tokens_org_text)
    # Clean and lemmatize the text, then check its length against the model limit
    input_text = clean_and_lemmatize(input_text)
    tokens = tokenizer.tokenize(input_text)
    sequence_length = len(tokens)
    if sequence_length >= 1024:
        return jsonify({'error': 'Input text exceeds maximum token length of 1023.'})
    # Load the fine-tuned model and move it to the selected device
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    gen_kwargs = {"length_penalty": 0.8, "num_beams": 8, "max_length": 128}
    pipe = pipeline("summarization", model=model, tokenizer=tokenizer, device=0 if device == "cuda" else -1)
    text = pipe(input_text, **gen_kwargs)[0]["summary_text"]
    output_text = replace_pronouns(remove_spaces_before_punctuation(text))
    # Clear the GPU cache
    if device == "cuda":
        torch.cuda.empty_cache()
    # Return the summary
    return jsonify({'summary': output_text})

@app.route('/')
def index():
    return render_template('index.html')

if __name__ == '__main__':
    app.run(host='0.0.0.0', debug=True, port=7860)
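
# --- Example client call (a sketch, not part of the app) ---
# Assuming the server is running locally on port 7860, the /summarize endpoint
# accepts a form-encoded POST with an "input_text" field, e.g.:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/summarize",
#       data={"input_text": "Alice: Are we still on for 5 pm? Bob: Yes, see you at the cafe."},
#   )
#   print(resp.json())  # {'summary': '...'} or {'error': '...'}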