File size: 3,870 Bytes
f74338c
 
 
 
30864aa
 
 
f74338c
30864aa
f74338c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import html
import os
import re
from difflib import HtmlDiff

import pandas as pd
import torch
from flask import Flask, render_template, request

#os.environ['HF_HOME'] = '/remote/t3dev4/anmolm/sanchit/mlapp/huggingface'
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration

app = Flask(__name__)

# Load the Grammarly CoEdit-Large tokenizer and model once, at import time, so
# every request reuses the same instances.
tokenizer = AutoTokenizer.from_pretrained("grammarly/coedit-large")
model = T5ForConditionalGeneration.from_pretrained("grammarly/coedit-large")

# Load the style-guide word list. Resolve it next to this module instead of the
# current working directory so the app finds it regardless of where it is launched.
custom_dataset_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "styleguide_words.csv"
)
custom_dataset = pd.read_csv(custom_dataset_path)

# Map each disallowed word to its suggested replacement. Rows whose "Not Allowed"
# cell is empty are dropped: pandas reads such cells as NaN, which is not a
# usable search key.
valid_rows = custom_dataset.dropna(subset=["Not Allowed"])
replacement_mapping = dict(zip(valid_rows["Not Allowed"], valid_rows["Replacement"]))

@app.route('/')
def index():
    """Render the ``index.html`` template for the site root (``/``)."""
    template_name = 'index.html'
    return render_template(template_name)

@app.route('/correct', methods=['POST'])
def correct():
    """Grammar-correct the posted ``text`` field and render ``result.html``.

    The template receives both the submitted text (``original_text``) and the
    output of :func:`grammar_correction` (``corrected_text``).
    """
    submitted = request.form['text']
    context = {
        'original_text': submitted,
        'corrected_text': grammar_correction(submitted),
    }
    return render_template('result.html', **context)

@app.route('/styleguide', methods=['POST'])
def styleguide():
    """Run the style-guide check on the posted ``corrected_text`` field.

    Renders ``styleguide.html`` with the submitted text, its highlighted HTML
    version, and the list of suggestions from :func:`apply_styleguide`.
    """
    corrected = request.form['corrected_text']
    marked_up, hints = apply_styleguide(corrected)
    return render_template(
        'styleguide.html',
        corrected_text=corrected,
        highlighted_text=marked_up,
        suggestions=hints,
    )

@app.route('/compare', methods=['POST'])
def compare():
    """Diff the posted ``original_text`` against ``final_text`` and render ``compare.html``."""
    form = request.form
    before = form['original_text']
    after = form['final_text']
    diff_table = highlight_changes(before, after)
    return render_template(
        'compare.html',
        original_text=before,
        final_text=after,
        highlighted_changes=diff_table,
    )

def grammar_correction(text, max_length=256):
    """Correct grammar in *text* piece by piece with the CoEdit model.

    The text is split on ". ", each piece is corrected independently, and the
    pieces are re-joined with ". ". Pieces that look like commands (start with
    "-" or contain "_") and empty/blank pieces are passed through unchanged.

    Args:
        text: Raw text submitted by the user.
        max_length: Maximum generation length per piece, forwarded to
            ``model.generate``.

    Returns:
        The corrected text as a single string.
    """
    segments = text.split(". ")
    last_index = len(segments) - 1
    corrected_segments = []
    for index, segment in enumerate(segments):
        # Blank pieces (empty input, or a trailing ". ") have nothing to correct;
        # sending them to the model would only produce spurious output.
        if not segment.strip() or _is_command(segment):
            corrected_segments.append(segment)
            continue
        corrected = _correct_sentence(segment, max_length)
        # The ". " separator is re-inserted by the join below, so a period the
        # model appended to a non-final piece would otherwise yield "..".
        if index < last_index and corrected.endswith(".") and not segment.endswith("."):
            corrected = corrected[:-1]
        corrected_segments.append(corrected)
    return ". ".join(corrected_segments)


def _is_command(segment):
    """Return True if *segment* looks like a command line (starts with "-" or contains "_")."""
    return segment.startswith("-") or "_" in segment


def _correct_sentence(sentence, max_length):
    """Run one sentence through the CoEdit model and return the decoded output."""
    # NOTE(review): CoEdit is instruction-tuned and its model card prefixes inputs
    # with a task instruction (e.g. "Fix grammatical errors in this sentence: ").
    # The sentence is sent bare here -- confirm that is intended.
    input_ids = tokenizer(sentence, return_tensors="pt").input_ids
    outputs = model.generate(input_ids, max_length=max_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def apply_styleguide(text, mapping=None, *, whole_words=True):
    """Highlight style-guide violations in *text* and collect replacement suggestions.

    Every occurrence of a disallowed word is wrapped in a yellow ``<span>`` and
    followed by its suggested replacement in parentheses. Because the result
    contains markup, all user text and mapping values are HTML-escaped.

    Args:
        text: Text to check (user-submitted).
        mapping: Disallowed word -> replacement. Defaults to the module-level
            ``replacement_mapping`` loaded from the style-guide CSV.
        whole_words: When true (default), an entry only matches when it is not
            part of a longer word, so "in" does not fire inside "within". Pass
            False for plain substring matching.

    Returns:
        ``(highlighted_html, suggestions)``. *suggestions* lists
        ``(disallowed_word, replacement)`` pairs, in mapping order, for every
        entry found in *text*.
    """
    if mapping is None:
        mapping = replacement_mapping
    # Empty CSV cells arrive from pandas as float NaN; anything that is not a
    # non-blank string cannot be searched for, so skip it.
    words = [word for word in mapping if isinstance(word, str) and word.strip()]
    if not words:
        return html.escape(text), []

    # Longest first so a phrase wins over a shorter entry starting at the same spot.
    alternation = "|".join(re.escape(word) for word in sorted(words, key=len, reverse=True))
    if whole_words:
        alternation = rf"(?<!\w)(?:{alternation})(?!\w)"
    pattern = re.compile(alternation)

    # Single pass over the ORIGINAL text: markup inserted for one entry is never
    # re-scanned, so an entry like "yellow" cannot match inside a span's style.
    pieces = []
    found = set()
    cursor = 0
    for match in pattern.finditer(text):
        word = match.group()
        found.add(word)
        pieces.append(html.escape(text[cursor:match.start()]))
        pieces.append(
            '<span style="background-color: yellow">'
            f"{html.escape(word)}</span> ({html.escape(str(mapping[word]))})"
        )
        cursor = match.end()
    pieces.append(html.escape(text[cursor:]))

    suggestions = [(word, mapping[word]) for word in words if word in found]
    return "".join(pieces), suggestions
   
def highlight_changes(original_text, final_text):
    """Build an HTML side-by-side diff table of *original_text* versus *final_text*.

    The texts are compared line by line. Only changed lines plus two lines of
    surrounding context are shown. The return value is the ``<table>`` fragment
    produced by :meth:`difflib.HtmlDiff.make_table`.
    """
    before_lines = original_text.splitlines()
    after_lines = final_text.splitlines()
    differ = HtmlDiff()
    return differ.make_table(before_lines, after_lines, context=True, numlines=2)
   
if __name__ == '__main__':
    # The Werkzeug interactive debugger lets anyone who can reach the page run
    # arbitrary code, so keep it off unless explicitly requested via FLASK_DEBUG=1.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in {'1', 'true', 'yes'}
    app.run(debug=debug_enabled)