Bikas0 committed
Commit bd2904e (verified)
1 Parent(s): faa4c32

update nltk path

Files changed (1)
  1. app.py +87 -84
app.py CHANGED
@@ -1,84 +1,87 @@
- from flask import Flask, request, render_template, jsonify
- import re
- import nltk
- import torch
- from pathlib import Path
-
- # Define the device if using GPU
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- from transformers import pipeline
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
- from nltk.tokenize import word_tokenize
- from nltk.stem import WordNetLemmatizer
-
- nltk.download('punkt')
- nltk.download('wordnet')
-
- app = Flask(__name__)
-
- tokenizer = AutoTokenizer.from_pretrained(Path("summary/tokenizer"))
- model_name = "summary/pegasus-samsum-model"
-
- def remove_spaces_before_punctuation(text):
-     pattern = re.compile(r'(\s+)([.,;!?])')
-     result = pattern.sub(r'\2', text)
-     result = re.sub(r'\[|\]', '', result)
-     return result
-
- def replace_pronouns(text):
-     # Replace "they" with "he" or "she" based on context
-     text = re.sub(r'\bthey\b', 'He/She', text, flags=re.IGNORECASE)
-     text = re.sub(r'\b(are|have|were)\b', lambda x: {'are': 'is', 'have': 'has', 'were': 'was'}[x.group()], text)
-     return text
-
- def clean_and_lemmatize(text):
-     # Remove digits, symbols, punctuation marks, and newline characters
-     text = re.sub(r'\d+', '', text)
-     text = re.sub(r'[^\w\s,-]', '', text.replace('\n', ''))
-     # Tokenize the text
-     tokens = word_tokenize(text.lower())
-     # Initialize lemmatizer
-     lemmatizer = WordNetLemmatizer()
-     # Lemmatize each token and join back into a sentence
-     lemmatized_text = ' '.join([lemmatizer.lemmatize(token) for token in tokens])
-     return lemmatized_text
-
- @app.route('/summarize', methods=['POST'])
- def summarize():
-     # Get the input text from the request
-     input_text = request.form['input_text']
-
-     # Tokenize the input text
-     tokens_org_text = tokenizer.tokenize(input_text)
-     sequence_length_org_text = len(tokens_org_text)
-
-     input_text = clean_and_lemmatize(input_text)
-     tokens = tokenizer.tokenize(input_text)
-     sequence_length = len(tokens)
-
-     if sequence_length >= 1024:
-         return jsonify({'error': 'Input text exceeds maximum token length of 1023.'})
-
-     # Initialize model variable
-     model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
-
-     gen_kwargs = {"length_penalty": 0.8, "num_beams": 8, "max_length": 128}
-     pipe = pipeline("summarization", model=model, tokenizer=tokenizer, device=device)
-
-     text = pipe(input_text, **gen_kwargs)[0]["summary_text"]
-     output_text = replace_pronouns(remove_spaces_before_punctuation(text))
-
-     # Clear the GPU cache
-     torch.cuda.empty_cache()
-
-     # Return the summary
-     return jsonify({'summary': output_text})
-
- @app.route('/')
- def index():
-     return render_template('index.html')
-
- if __name__ == '__main__':
-     app.run(host='0.0.0.0', debug=True, port=7860)  # This is Host Port
-
+ from flask import Flask, request, render_template, jsonify
+ import re
+ import nltk
+ import torch
+ from pathlib import Path
+
+ # Define the device if using GPU
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ from transformers import pipeline
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ from nltk.tokenize import word_tokenize
+ from nltk.stem import WordNetLemmatizer
+
+ # nltk.download('punkt')
+ # nltk.download('wordnet')
+ # Ensure NLTK data is downloaded
+ nltk.download('punkt', download_dir=Path('/app/nltk_data'))
+ nltk.download('wordnet', download_dir=Path('/app/nltk_data'))
+
+ app = Flask(__name__)
+
+ tokenizer = AutoTokenizer.from_pretrained(Path("summary/tokenizer"))
+ model_name = "summary/pegasus-samsum-model"
+
+ def remove_spaces_before_punctuation(text):
+     pattern = re.compile(r'(\s+)([.,;!?])')
+     result = pattern.sub(r'\2', text)
+     result = re.sub(r'\[|\]', '', result)
+     return result
+
+ def replace_pronouns(text):
+     # Replace "they" with "he" or "she" based on context
+     text = re.sub(r'\bthey\b', 'He/She', text, flags=re.IGNORECASE)
+     text = re.sub(r'\b(are|have|were)\b', lambda x: {'are': 'is', 'have': 'has', 'were': 'was'}[x.group()], text)
+     return text
+
+ def clean_and_lemmatize(text):
+     # Remove digits, symbols, punctuation marks, and newline characters
+     text = re.sub(r'\d+', '', text)
+     text = re.sub(r'[^\w\s,-]', '', text.replace('\n', ''))
+     # Tokenize the text
+     tokens = word_tokenize(text.lower())
+     # Initialize lemmatizer
+     lemmatizer = WordNetLemmatizer()
+     # Lemmatize each token and join back into a sentence
+     lemmatized_text = ' '.join([lemmatizer.lemmatize(token) for token in tokens])
+     return lemmatized_text
+
+ @app.route('/summarize', methods=['POST'])
+ def summarize():
+     # Get the input text from the request
+     input_text = request.form['input_text']
+
+     # Tokenize the input text
+     tokens_org_text = tokenizer.tokenize(input_text)
+     sequence_length_org_text = len(tokens_org_text)
+
+     input_text = clean_and_lemmatize(input_text)
+     tokens = tokenizer.tokenize(input_text)
+     sequence_length = len(tokens)
+
+     if sequence_length >= 1024:
+         return jsonify({'error': 'Input text exceeds maximum token length of 1023.'})
+
+     # Initialize model variable
+     model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
+
+     gen_kwargs = {"length_penalty": 0.8, "num_beams": 8, "max_length": 128}
+     pipe = pipeline("summarization", model=model, tokenizer=tokenizer, device=device)
+
+     text = pipe(input_text, **gen_kwargs)[0]["summary_text"]
+     output_text = replace_pronouns(remove_spaces_before_punctuation(text))
+
+     # Clear the GPU cache
+     torch.cuda.empty_cache()
+
+     # Return the summary
+     return jsonify({'summary': output_text})
+
+ @app.route('/')
+ def index():
+     return render_template('index.html')
+
+ if __name__ == '__main__':
+     app.run(host='0.0.0.0', debug=True, port=7860)  # This is Host Port
+
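
Note: NLTK resolves resources through the directories listed in nltk.data.path (seeded from the NLTK_DATA environment variable), so a package downloaded to a custom download_dir such as /app/nltk_data is only found at runtime if that directory is also on the search path. A minimal sketch of the extra registration step, assuming the same /app/nltk_data location used in this commit:

    import nltk
    from pathlib import Path

    nltk_dir = Path('/app/nltk_data')  # same custom path as in this commit

    # Downloading alone does not extend the search path; register the
    # directory so word_tokenize() and WordNetLemmatizer can find the data.
    nltk.data.path.append(str(nltk_dir))

    nltk.download('punkt', download_dir=str(nltk_dir))
    nltk.download('wordnet', download_dir=str(nltk_dir))

Alternatively, setting NLTK_DATA=/app/nltk_data in the container environment (e.g. in the Space's Dockerfile, which is not part of this commit) achieves the same result without the append.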