import os
from flask import Flask, request, render_template, jsonify
import re
import nltk
import torch
from pathlib import Path
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
# Ensure NLTK uses the correct data path (only if NLTK_DATA is set;
# appending None would break NLTK's path lookup later)
nltk_data_dir = os.getenv('NLTK_DATA')
if nltk_data_dir:
    nltk.data.path.append(nltk_data_dir)
# Define the device if using GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
app = Flask(__name__)
# Point the Transformers cache at the configured directory, if one is provided
# (assigning os.getenv() directly would raise a TypeError when the variable is unset)
transformers_cache = os.getenv('TRANSFORMERS_CACHE')
if transformers_cache:
    os.environ['TRANSFORMERS_CACHE'] = transformers_cache
tokenizer = AutoTokenizer.from_pretrained(Path("summary/tokenizer"))
model_name = "summary/pegasus-samsum-model"
def remove_spaces_before_punctuation(text):
    # Drop any whitespace that directly precedes punctuation
    pattern = re.compile(r'(\s+)([.,;!?])')
    result = pattern.sub(r'\2', text)
    # Strip stray square brackets left over from generation
    result = re.sub(r'\[|\]', '', result)
    return result
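# Illustrative example: remove_spaces_before_punctuation("Hi , there !")
# returns "Hi, there!"; square brackets are stripped as well.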
def replace_pronouns(text):
# Replace "they" with "he" or "she" based on context
text = re.sub(r'\bthey\b', 'He/She', text, flags=re.IGNORECASE)
text = re.sub(r'\b(are|have|were)\b', lambda x: {'are': 'is', 'have': 'has', 'were': 'was'}[x.group()], text)
return text
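# Illustrative example: replace_pronouns("they are ready") returns
# "He/She is ready". Note the substitution applies everywhere, so plural
# subjects ("the results are") are also forced to singular.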
def clean_and_lemmatize(text):
# Remove digits, symbols, punctuation marks, and newline characters
text = re.sub(r'\d+', '', text)
text = re.sub(r'[^\w\s,-]', '', text.replace('\n', ''))
# Tokenize the text
tokens = word_tokenize(text.lower())
# Initialize lemmatizer
lemmatizer = WordNetLemmatizer()
# Lemmatize each token and join back into a sentence
lemmatized_text = ' '.join([lemmatizer.lemmatize(token) for token in tokens])
return lemmatized_text
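# Illustrative example: clean_and_lemmatize("The 3 cats\nwere running!")
# returns roughly "the cat were running": digits and punctuation are gone,
# "cats" is lemmatized to "cat", and the default noun-mode lemmatizer
# leaves verb forms like "running" untouched.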
@app.route('/summarize', methods=['POST'])
def summarize():
    # Get the input text from the request
    input_text = request.form['input_text']
    # Token count of the raw input (currently informational only)
    tokens_org_text = tokenizer.tokenize(input_text)
    sequence_length_org_text = len(tokens_org_text)
    # Clean the text, then re-check its length against the model limit
    input_text = clean_and_lemmatize(input_text)
    tokens = tokenizer.tokenize(input_text)
    sequence_length = len(tokens)
    if sequence_length >= 1024:
        return jsonify({'error': 'Input text exceeds maximum token length of 1023.'})
    # Load the fine-tuned model (note: reloaded on every request)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    gen_kwargs = {"length_penalty": 0.8, "num_beams": 8, "max_length": 128}
    pipe = pipeline("summarization", model=model, tokenizer=tokenizer, device=0 if device == "cuda" else -1)
    text = pipe(input_text, **gen_kwargs)[0]["summary_text"]
    output_text = replace_pronouns(remove_spaces_before_punctuation(text))
    # Clear the GPU cache
    if device == "cuda":
        torch.cuda.empty_cache()
    # Return the summary
    return jsonify({'summary': output_text})
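# Illustrative request, assuming the app is running locally on port 7860:
#   curl -X POST --data-urlencode "input_text=John: are we still on for lunch?" \
#        http://localhost:7860/summarize
# The response is JSON: {"summary": "..."}, or {"error": "..."} for over-long inputs.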
@app.route('/')
def index():
    return render_template('index.html')
if __name__ == '__main__':
    # debug=True is convenient during development; disable it in production
    app.run(host='0.0.0.0', debug=True, port=7860)