hoang1007's picture
Upload 69 files
44db343
from typing import Union
import sys
sys.path.append("..")
from params import *
from dataset.vocab import Vocab
from models.corrector import Corrector
from models.model import ModelWrapper
from models.util import load_weights
from dataset.noise import SynthesizeData
from utils.api_utils import correctFunction, postprocessing_result
# Model / dataset selection for this demo. The vocabulary and checkpoint
# locations below are derived from these two names.
model_name = "tfmwtr"
dataset = "binhvq"
vocab_path = f'data/{dataset}/{dataset}.vocab.pkl'
# The checkpoint directory follows model_name (it used to be hard-coded to
# "tfmwtr"), so switching models cannot silently load another model's weights.
weight_path = f'data/checkpoints/{model_name}/{dataset}.weights.pth'

# NOTE(review): "vi" presumably selects the Vietnamese vocabulary; confirm in Vocab.
vocab = Vocab("vi")
vocab.load_vocab_dict(vocab_path)
# NOTE(review): noiser is not referenced elsewhere in this file; kept because
# other modules may import it from here.
noiser = SynthesizeData(vocab)
model_wrapper = ModelWrapper(model_name, vocab)
corrector = Corrector(model_wrapper)
# Weights are loaded into the corrector's underlying model in place.
load_weights(corrector.model, weight_path)
def correct(string: str):
    """Run the corrector on *string* and format the result for gr.HighlightedText.

    Returns a list of ``(text, label)`` tuples. Each entry yielded by
    ``postprocessing_result`` contributes one tuple:
    - Its string elements are stripped and non-string elements are ignored.
    - If exactly two strings remain, the second becomes the label.
    - Otherwise the label is ``None``.

    A ``(" ", None)`` spacer is placed between consecutive tuples (never
    leading or trailing) so the rendered output keeps word spacing.

    Entries with no string elements are skipped. An empty result yields an
    empty list instead of raising IndexError.
    """
    out = correctFunction(string, corrector)
    result = postprocessing_result(out)
    ret = []
    for r in result:
        # Keep only the string fields, trimmed of surrounding whitespace.
        parts = [s.strip() for s in r if isinstance(s, str)]
        if not parts:
            # Nothing displayable; indexing parts[0] would raise IndexError.
            continue
        # NOTE(review): the second string is presumably the suggested
        # correction shown as the highlight label; confirm in api_utils.
        label = parts[1] if len(parts) == 2 else None
        if ret:
            # Separator only between tokens, so no trailing pop is needed
            # (popping an empty list raised IndexError on empty input).
            ret.append((" ", None))
        ret.append((parts[0], label))
    print(ret, "RET")  # debug trace, kept as in the original
    return ret
import gradio as gr
if __name__ == "__main__":
    # Paint the labels of the highlighted output component green.
    # Written as a flat descendant selector instead of native CSS nesting:
    # browsers that do not support nesting discard the whole nested rule,
    # whereas "#output .label" works everywhere and targets the same elements.
    css = """
#output .label {
    background-color: green !important;
}
"""
    gr.Interface(
        correct,
        inputs=gr.Textbox(label="Input", placeholder="Enter text to be corrected here..."),
        # combine_adjacent merges neighbouring spans that share a label,
        # e.g. the (" ", None) spacers with unlabelled words.
        outputs=gr.HighlightedText(
            label="Output",
            combine_adjacent=True,
            show_label=True,
            elem_id="output",  # matched by the #output selector in css
        ),
        theme=gr.themes.Default(),
        css=css,
    ).launch()