# Spaces: Running on Zero
import gc | |
import logging | |
import os | |
import re | |
import spaces | |
import torch | |
from cleantext import clean | |
import gradio as gr | |
from tqdm.auto import tqdm | |
from transformers import pipeline | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
# Configure root logging and record the installed torch build at startup.
logging.basicConfig(level=logging.INFO)
logging.info(f"torch version:\t{torch.__version__}")

# Hub identifiers of the two models used by the app: the first is used for
# grammar quality detection, the second for grammar correction.
checker_model_name = "textattack/roberta-base-CoLA"
corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"

# Both pipelines are built once, at import time, and requested on "cuda".
checker = pipeline(
    task="text-classification",
    model=checker_model_name,
    device_map="cuda",
)
corrector = pipeline(
    task="text2text-generation",
    model=corrector_model_name,
    device_map="cuda",
)
# Sentence boundary: one or more spaces that are preceded by the three
# characters "<non-uppercase><any char><'.' or '?'>" and followed by an
# uppercase ASCII letter.  Only '.' and '?' terminate a sentence here.
_SENTENCE_BOUNDARY = re.compile(r"(?<=[^A-Z].[.?]) +(?=[A-Z])")


def split_text(text: str, batch_size: int = 2) -> list:
    """Split *text* into sentences and group them into consecutive batches.

    Sentences are found with a regex boundary (see ``_SENTENCE_BOUNDARY``) and
    then chunked positionally, so every batch holds ``batch_size`` sentences
    except possibly the last one, which holds the remainder.

    Args:
        text: The raw text to split.
        batch_size: Maximum number of sentences per batch. Defaults to 2,
            which is the batch size the previous implementation produced.

    Returns:
        A list of batches, each batch being a list of sentence strings.
        Empty or whitespace-only input yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    # Nothing to batch: avoid handing an empty string to downstream models.
    if not text or not text.strip():
        return []

    sentences = _SENTENCE_BOUNDARY.split(text)

    # Chunk by position rather than comparing each sentence with the last one
    # by value: a repeated sentence equal to the final sentence must not end
    # a batch early.
    return [
        sentences[start : start + batch_size]
        for start in range(0, len(sentences), batch_size)
    ]
def correct_text(text: str, separator: str = " ") -> str:
    """Run grammar correction over *text*, one sentence batch at a time.

    The text is split with ``split_text``; each batch is joined with a single
    space and scored by ``checker``. A batch is passed through ``corrector``
    unless the checker's top result is ``LABEL_1`` with a score of at least
    0.9, in which case the batch text is kept unchanged.

    Args:
        text: The text to correct.
        separator: String placed between the per-batch results.

    Returns:
        The per-batch results joined with ``separator``.
    """
    pieces = []
    batches = split_text(text)
    progress = tqdm(batches, total=len(batches), desc="correcting text..")

    for group in progress:
        passage = " ".join(group)
        verdict = checker(passage)[0]

        # Any label other than LABEL_1, or a LABEL_1 below the 0.9 confidence
        # bar, sends the passage to the corrector.
        needs_fix = verdict["label"] != "LABEL_1" or verdict["score"] < 0.9
        if not needs_fix:
            pieces.append(passage)
            continue

        pieces.append(corrector(passage)[0]["generated_text"])

    return separator.join(pieces)
def update(text: str):
    """Click handler: truncate, clean and grammar-correct the input text.

    The input is cut to its first 4000 characters *before* cleaning, then
    passed through ``clean(..., lower=False)`` and on to ``correct_text``.

    NOTE(review): ``spaces`` is imported by this file but no ``@spaces.GPU``
    decorator is visible on any function in this view; confirm whether the
    GPU-using handler needs one on ZeroGPU hardware.

    Args:
        text: Raw text from the input textbox.

    Returns:
        The corrected text produced by ``correct_text``.
    """
    snippet = text[:4000]
    cleaned = clean(snippet, lower=False)
    return correct_text(cleaned)
# Create the Gradio interface
# Components are declared top to bottom: title, instructions, model list,
# the input/output textboxes, the trigger button and a footer link.
with gr.Blocks() as demo:
    gr.Markdown("# <center>Robust Grammar Correction with FLAN-T5</center>")
    gr.Markdown(
        "**Instructions:** Enter the text you want to correct in the textbox below (_text will be truncated to 4000 characters_). Click 'Process' to run."
    )
    gr.Markdown(
        """Models:
- `textattack/roberta-base-CoLA` for grammar quality detection
- `pszemraj/flan-t5-large-grammar-synthesis` for grammar correction
"""
    )
    # NOTE(review): the original indentation was lost; the two textboxes are
    # assumed to be the only children of this Row -- confirm against the app.
    with gr.Row():
        inp = gr.Textbox(
            label="input",
            placeholder="Enter text to check & correct",
            value="I wen to the store yesturday to bye some food. I needd milk, bread, and a few otter things. The store was really crowed and I had a hard time finding everyting I needed. I finaly made it to the check out line and payed for my stuff.",
        )
        out = gr.Textbox(label="output", interactive=False)
    btn = gr.Button("Process")
    # Clicking the button sends the input textbox to `update` and writes the
    # returned string into the output textbox.
    btn.click(fn=update, inputs=inp, outputs=out)
    gr.Markdown("---")
    gr.Markdown(
        "- See the [model card](https://huggingface.co/pszemraj/flan-t5-large-grammar-synthesis) for more info"
    )
# Launch the demo
demo.launch(debug=True)