|
import torch |
|
from flask import Flask, render_template, request |
|
from difflib import HtmlDiff |
|
import pandas as pd |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration |
|
app = Flask(__name__)

# Tokenizer and seq2seq model consumed by grammar_correction(); both come from
# the same pretrained checkpoint and are loaded once, at import time.
tokenizer = AutoTokenizer.from_pretrained("grammarly/coedit-large")
model = T5ForConditionalGeneration.from_pretrained("grammarly/coedit-large")

# Style-guide word list. apply_styleguide() reads the resulting mapping from
# the "Not Allowed" column to the "Replacement" column (later rows win on
# duplicate keys).
custom_dataset_path = "styleguide_words.csv"
custom_dataset = pd.read_csv(custom_dataset_path)

replacement_mapping = {
    banned: preferred
    for banned, preferred in zip(
        custom_dataset["Not Allowed"], custom_dataset["Replacement"]
    )
}
|
|
|
@app.route('/')
def index():
    """Render the ``index.html`` template for the site root."""
    page = 'index.html'
    return render_template(page)
|
|
|
@app.route('/correct', methods=['POST'])
def correct():
    """Grammar-correct the POSTed ``text`` field and render ``result.html``.

    The template receives both the submitted text and the corrected text.
    """
    submitted = request.form['text']
    return render_template(
        'result.html',
        original_text=submitted,
        corrected_text=grammar_correction(submitted),
    )
|
|
|
@app.route('/styleguide', methods=['POST'])
def styleguide():
    """Check the POSTed ``corrected_text`` against the style guide.

    Renders ``styleguide.html`` with the text, its highlighted version and the
    list of (not allowed, replacement) suggestions from apply_styleguide().
    """
    corrected = request.form['corrected_text']
    markup, hints = apply_styleguide(corrected)
    context = {
        'corrected_text': corrected,
        'highlighted_text': markup,
        'suggestions': hints,
    }
    return render_template('styleguide.html', **context)
|
|
|
@app.route('/compare', methods=['POST'])
def compare():
    """Render ``compare.html`` with an HTML diff of the two POSTed texts."""
    form = request.form
    before = form['original_text']
    after = form['final_text']
    diff_table = highlight_changes(before, after)
    return render_template(
        'compare.html',
        original_text=before,
        final_text=after,
        highlighted_changes=diff_table,
    )
|
|
|
def grammar_correction(text, *, instruction="Fix grammatical errors in this sentence: "):
    """Run each sentence of *text* through the CoEdit model and rejoin them.

    *text* is split naively on ``". "``. These fragments are passed through
    unchanged:

    * blank fragments;
    * fragments starting with ``-``;
    * fragments containing ``_``.

    The ``-`` and ``_`` cases are presumably list items and identifiers
    (TODO confirm). Every other fragment is prefixed with *instruction* and
    rewritten by the module-level ``model``.

    Args:
        text: Raw text submitted by the user.
        instruction: Task prefix prepended to each fragment. CoEdit is
            instruction-tuned and expects inputs of the form
            ``"<instruction>: <text>"``.

    Returns:
        The processed fragments joined back together with ``". "``.
    """
    fragments = text.split(". ")
    last_index = len(fragments) - 1
    corrected = []
    for index, fragment in enumerate(fragments):
        # Blank fragments (empty input, trailing ". ") would make the model
        # generate text from nothing; keep them verbatim.
        if not fragment.strip() or fragment.startswith("-") or "_" in fragment:
            corrected.append(fragment)
            continue

        edited = _correct_fragment(instruction + fragment)
        # split(". ") stripped the period from every fragment but the last,
        # and the join below re-inserts it; drop the period the model added
        # itself so the output does not contain "..".
        if index < last_index and edited.endswith(".") and not fragment.endswith("."):
            edited = edited[:-1]
        corrected.append(edited)

    return ". ".join(corrected)


def _correct_fragment(prompt):
    """Return the model's decoded rewrite of *prompt* (first generated sequence)."""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    outputs = model.generate(input_ids, max_length=256)
    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|
|
|
def apply_styleguide(text, mapping=None):
    """Highlight style-guide violations in *text* and collect suggestions.

    Each occurrence of a "not allowed" phrase is wrapped in a yellow
    ``<span>`` and followed by its replacement in parentheses. Matching is
    done in a single left-to-right pass over the original text, with longer
    phrases preferred. Inserted markup and replacement text are therefore
    never re-scanned. All user text is HTML-escaped because the result is
    rendered as HTML.

    Args:
        text: Text to check (untrusted user input).
        mapping: Optional dict of not-allowed phrase -> replacement. Defaults
            to the module-level ``replacement_mapping``.

    Returns:
        ``(highlighted_html, suggestions)``, where ``suggestions`` lists the
        ``(not_allowed, replacement)`` pairs that matched, in mapping order.
    """
    import html
    import re

    if mapping is None:
        mapping = replacement_mapping

    # Ignore blank or non-string keys (e.g. NaN from empty CSV cells): an
    # empty pattern matches between every character and a float key would
    # raise TypeError on `in`.
    terms = {
        key: value
        for key, value in mapping.items()
        if isinstance(key, str) and key
    }
    if not terms:
        return html.escape(text), []

    # Longest phrases first so the regex alternation prefers the longer match
    # when two phrases start at the same position.
    pattern = re.compile(
        "|".join(re.escape(term) for term in sorted(terms, key=len, reverse=True))
    )

    pieces = []
    matched = set()
    last_end = 0
    for match in pattern.finditer(text):
        term = match.group()
        matched.add(term)
        pieces.append(html.escape(text[last_end:match.start()]))
        pieces.append(
            f'<span style="background-color: yellow">{html.escape(term)}</span>'
            f' ({html.escape(str(terms[term]))})'
        )
        last_end = match.end()
    pieces.append(html.escape(text[last_end:]))

    suggestions = [(term, repl) for term, repl in terms.items() if term in matched]
    return "".join(pieces), suggestions
|
|
|
def highlight_changes(original_text, final_text):
    """Return an HTML side-by-side diff table of the two texts.

    The texts are compared line by line. Only changed lines plus two lines
    of context around each change are shown (``context=True, numlines=2``).
    """
    before_lines = original_text.splitlines()
    after_lines = final_text.splitlines()
    return HtmlDiff().make_table(
        before_lines, after_lines, context=True, numlines=2
    )
|
|
|
if __name__ == '__main__':
    import os

    # Werkzeug's interactive debugger can execute arbitrary Python from the
    # browser, so allow disabling it (FLASK_DEBUG=0/false/no) without editing
    # the file. This mirrors Flask's own FLASK_DEBUG parsing. The default
    # keeps the original debug=True behaviour.
    debug_flag = os.environ.get("FLASK_DEBUG", "1")
    app.run(debug=debug_flag.lower() not in ("0", "false", "no"))