File size: 1,483 Bytes
d43521a
 
 
c42314f
d43521a
a184a4d
d43521a
5904f8c
d43521a
3915661
d43521a
ce0ccdb
d43521a
4c3f294
d43521a
d82c758
019744c
478e5c0
d43521a
 
 
 
 
 
 
 
 
 
 
c2c67ea
d43521a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import streamlit as st

# --- Page header and input widgets -------------------------------------
st.title("Grammar Corrector")
st.write("Paste or type text, select number of correction options, enter and the machine will present attempts to correct your text's grammar.")

# Pre-filled sample so the app demonstrates itself on first load.
default_text = "In conclusion,if anyone has some problem the customers must be returned."

# `sent` and `num_correct_options` are read further down by the
# generation code, so those names must stay as-is.
sent = st.text_area(
    "Text",
    default_text,
    height=40,
)
num_correct_options = st.number_input(
    'Number of Correction Options',
    min_value=1,
    max_value=3,
    value=1,
    step=1,
)

from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch

# Run on GPU when available, otherwise fall back to CPU.
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'


@st.cache_resource
def _load_grammar_model():
    """Load the tokenizer and model exactly once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    without caching, the (large) T5 checkpoint would be re-instantiated
    on every keystroke. `st.cache_resource` memoizes the loaded objects.
    """
    tok = T5Tokenizer.from_pretrained('deep-learning-analytics/GrammarCorrector')
    mdl = T5ForConditionalGeneration.from_pretrained(
        'deep-learning-analytics/GrammarCorrector'
    ).to(torch_device)
    return tok, mdl


# Module-level names preserved: the rest of the script uses
# `tokenizer` and `model` directly.
tokenizer, model = _load_grammar_model()

def correct_grammar(input_text, num_return_sequences=num_correct_options):
    """Return grammar-correction candidates for *input_text*.

    Parameters
    ----------
    input_text : str
        The sentence/paragraph to correct (truncated to 64 tokens).
    num_return_sequences : int
        How many candidate corrections to generate. Defaults to the
        current widget value (re-evaluated on each Streamlit rerun).

    Returns
    -------
    torch.Tensor of token-id sequences, one row per candidate.
    """
    batch = tokenizer(
        [input_text],
        truncation=True,
        padding='max_length',
        max_length=64,
        return_tensors='pt',
    ).to(torch_device)
    # Bug fix: the original ignored the `num_return_sequences` parameter
    # and read the global `num_correct_options` instead, silently
    # overriding any explicit argument.
    results = model.generate(
        **batch,
        max_length=64,
        num_beams=5,
        num_return_sequences=num_return_sequences,
        # NOTE(review): temperature has no effect under pure beam search
        # (no do_sample=True) — kept for behavior parity; confirm intent.
        temperature=1.5,
    )
    return results
  
# Run generation for the current input and requested option count.
results = correct_grammar(sent, num_correct_options)

# Decode each generated beam into readable text, dropping special
# tokens (<pad>, </s>, ...) and cleaning up tokenization artifacts.
generated_options = [
    tokenizer.decode(
        option_ids,
        clean_up_tokenization_spaces=True,
        skip_special_tokens=True,
    )
    for option_ids in results
]

st.write(generated_options)