eljanmahammadli commited on
Commit
f75d1f0
·
1 Parent(s): d3fb13e

added writing analysis code

Browse files
Files changed (2) hide show
  1. requirements.txt +4 -1
  2. writing_analysis.py +194 -0
requirements.txt CHANGED
@@ -16,4 +16,7 @@ joblib
16
  evaluate
17
  tensorflow
18
  keras
19
- spacy
 
 
 
 
16
  evaluate
17
  tensorflow
18
  keras
19
+ spacy
20
+ textstat
21
+ plotly
22
+ tqdm
writing_analysis.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re, nltk, spacy, textstat, subprocess
2
+ from nltk import FreqDist
3
+ from nltk.corpus import stopwords
4
+ from nltk.tokenize import word_tokenize, sent_tokenize
5
+ from transformers import GPT2LMHeadModel, GPT2TokenizerFast
6
+ import torch
7
+ from tqdm import tqdm
8
+ import gradio as gr
9
+ import plotly.graph_objects as go
10
+
11
+ nltk.download('stopwords')
12
+ nltk.download('punkt')
13
+ nlp = spacy.load("en_core_web_sm")
14
+ command = ['python', '-m', 'spacy', 'download', 'en_core_web_sm', '-q']
15
+
16
+ # Execute the command
17
+ subprocess.run(command)
18
+
19
+ # for perplexity
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ model_id = "gpt2-large"
22
+ model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
23
+ tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
24
+
25
+ def normalize(value, min_value, max_value):
26
+ normalized_value = ((value - min_value) * 100) / (max_value - min_value)
27
+ return max(0, min(100, normalized_value))
28
+
29
+ # vocabulary richness
30
+ def preprocess_text1(text):
31
+ text = text.lower()
32
+ text = re.sub(r'[^\w\s]', '', text) # remove punctuation
33
+ stop_words = set(stopwords.words('english')) # remove stopwords
34
+ words = [word for word in text.split() if word not in stop_words]
35
+ words = [word for word in words if not word.isdigit()] # remove numbers
36
+ return words
37
+
38
+ def vocabulary_richness_ttr(words):
39
+ unique_words = set(words)
40
+ ttr = len(unique_words) / len(words) * 100
41
+ return ttr
42
+
43
+ def calculate_gunning_fog(text):
44
+ """range 0-20"""
45
+ gunning_fog = textstat.gunning_fog(text)
46
+ return gunning_fog
47
+
48
+ def calculate_automated_readability_index(text):
49
+ """range 1-20"""
50
+ ari = textstat.automated_readability_index(text)
51
+ return ari
52
+
53
+ def calculate_flesch_reading_ease(text):
54
+ """range 0-100"""
55
+ fre = textstat.flesch_reading_ease(text)
56
+ return fre
57
+
58
+ def preprocess_text2(text):
59
+ # tokenize into words and remove punctuation
60
+ sentences = sent_tokenize(text)
61
+ words = [word.lower() for sent in sentences for word in word_tokenize(sent) if word.isalnum()]
62
+ # remove stopwords
63
+ stop_words = set(stopwords.words('english'))
64
+ words = [word for word in words if word not in stop_words]
65
+ return words, sentences
66
+
67
+ def calculate_average_sentence_length(sentences):
68
+ """range 0-40 or 50 based on the histogram"""
69
+ total_words = sum(len(word_tokenize(sent)) for sent in sentences)
70
+ average_sentence_length = total_words / (len(sentences) + 0.0000001)
71
+ return average_sentence_length
72
+
73
+ def calculate_average_word_length(words):
74
+ """range 0-8 based on the histogram"""
75
+ total_characters = sum(len(word) for word in words)
76
+ average_word_length = total_characters / (len(words) + 0.0000001)
77
+ return average_word_length
78
+
79
+ def calculate_max_depth(sent):
80
+ return max(len(list(token.ancestors)) for token in sent)
81
+
82
+ def calculate_syntactic_tree_depth(text):
83
+ """0-10 based on the histogram"""
84
+ doc = nlp(text)
85
+ sentence_depths = [calculate_max_depth(sent) for sent in doc.sents]
86
+ average_depth = sum(sentence_depths) / len(sentence_depths) if sentence_depths else 0
87
+ return average_depth
88
+
89
+ # reference: https://huggingface.co/docs/transformers/perplexity
90
+ def calculate_perplexity(text, stride=512):
91
+ """range 0-30 based on the histogram"""
92
+ encodings = tokenizer(text, return_tensors="pt")
93
+ max_length = model.config.n_positions
94
+ seq_len = encodings.input_ids.size(1)
95
+
96
+ nlls = []
97
+ prev_end_loc = 0
98
+ for begin_loc in tqdm(range(0, seq_len, stride)):
99
+ end_loc = min(begin_loc + max_length, seq_len)
100
+ trg_len = end_loc - prev_end_loc # may be different from stride on last loop
101
+ input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
102
+ target_ids = input_ids.clone()
103
+ target_ids[:, :-trg_len] = -100
104
+
105
+ with torch.no_grad():
106
+ outputs = model(input_ids, labels=target_ids)
107
+ neg_log_likelihood = outputs.loss
108
+
109
+ nlls.append(neg_log_likelihood)
110
+
111
+ prev_end_loc = end_loc
112
+ if end_loc == seq_len:
113
+ break
114
+
115
+ ppl = torch.exp(torch.stack(nlls).mean())
116
+ return ppl.item()
117
+
118
+
119
+ def radar_plot(input_text):
120
+
121
+ # vocanulary richness
122
+ processed_words = preprocess_text1(input_text)
123
+ ttr_value = vocabulary_richness_ttr(processed_words)
124
+
125
+ # readability
126
+ gunning_fog = calculate_gunning_fog(input_text)
127
+ gunning_fog_norm = normalize(gunning_fog, min_value=0, max_value=20)
128
+
129
+ # average sentence length and average word length
130
+ words, sentences = preprocess_text2(input_text)
131
+ average_sentence_length = calculate_average_sentence_length(sentences)
132
+ average_word_length = calculate_average_word_length(words)
133
+ average_sentence_length_norm = normalize(average_sentence_length, min_value=0, max_value=40)
134
+ average_word_length_norm = normalize(average_word_length, min_value=0, max_value=8)
135
+
136
+ # syntactic_tree_depth
137
+ average_tree_depth = calculate_syntactic_tree_depth(input_text)
138
+ average_tree_depth_norm = normalize(average_tree_depth, min_value=0, max_value=10)
139
+
140
+ # perplexity
141
+ perplexity = calculate_perplexity(input_text)
142
+ perplexity_norm = normalize(perplexity, min_value=0, max_value=30)
143
+
144
+ features = {
145
+ "readability": gunning_fog_norm,
146
+ "syntactic tree depth": average_tree_depth_norm,
147
+ "vocabulary richness": ttr_value,
148
+ "perplexity": perplexity_norm,
149
+ "average sentence length": average_sentence_length_norm,
150
+ "average word length": average_word_length_norm,
151
+ }
152
+
153
+ print(features)
154
+
155
+ fig = go.Figure()
156
+
157
+ fig.add_trace(go.Scatterpolar(
158
+ r=list(features.values()),
159
+ theta=list(features.keys()),
160
+ fill='toself',
161
+ name='Radar Plot'
162
+ ))
163
+
164
+ fig.update_layout(
165
+ polar=dict(
166
+ radialaxis=dict(
167
+ visible=True,
168
+ range=[0, 100],
169
+ )),
170
+ showlegend=False,
171
+ # autosize=False,
172
+ # width=600,
173
+ # height=600,
174
+ margin=dict(
175
+ l=10,
176
+ r=20,
177
+ b=10,
178
+ t=10,
179
+ # pad=100
180
+ ),
181
+ )
182
+
183
+ return fig
184
+
185
+ # Gradio Interface
186
+ interface = gr.Interface(
187
+ fn=radar_plot,
188
+ inputs=gr.Textbox(label="Input text"),
189
+ outputs=gr.Plot(label="Radar Plot"),
190
+ title="Writing analysis",
191
+ description="Enter text for writing analysis",
192
+ )
193
+
194
+ interface.launch()