"""Gradio demo for the ``tfmwtr`` Vietnamese spelling-correction model.

Loads the vocabulary and model weights at import time, then (when run as a
script) serves a web UI that highlights the corrector's output.
"""
import logging
import sys
from typing import Union

import gradio as gr

# Project modules live one directory up (resolved relative to the CWD).
sys.path.append("..")

from params import *
from dataset.vocab import Vocab
from models.corrector import Corrector
from models.model import ModelWrapper
from models.util import load_weights
from dataset.noise import SynthesizeData
from utils.api_utils import correctFunction, postprocessing_result

logger = logging.getLogger(__name__)

model_name = "tfmwtr"
dataset = "binhvq"
vocab_path = f'data/{dataset}/{dataset}.vocab.pkl'
weight_path = 'data/checkpoints/tfmwtr/locdx.weights.pth'

vocab = Vocab("vi")
vocab.load_vocab_dict(vocab_path)
noiser = SynthesizeData(vocab)
model_wrapper = ModelWrapper(f"{model_name}", vocab)
corrector = Corrector(model_wrapper)
load_weights(corrector.model, weight_path)


def correct(string: str):
    """Correct *string* and format the result for ``gr.HighlightedText``.

    The text is run through ``correctFunction`` and ``postprocessing_result``.
    Each resulting segment is reduced to its stripped string items; a segment
    with exactly two items becomes ``(first, second)`` (presumably original
    text plus a label such as the correction — confirm against
    ``postprocessing_result``), any other non-empty segment becomes
    ``(first, None)``. Consecutive segments are separated by ``(" ", None)``.

    Returns:
        A list of ``(text, label_or_None)`` tuples; empty when there is
        nothing to display (e.g. empty input).
    """
    out = correctFunction(string, corrector)
    result = postprocessing_result(out)

    ret = []
    for segment in result:
        parts = [s.strip() for s in segment if isinstance(s, str)]
        if not parts:
            # No displayable text in this segment; indexing parts[0] would
            # raise IndexError, so skip it.
            continue
        if ret:
            # Separator between segments (no trailing separator to pop later,
            # which also avoids IndexError when the result is empty).
            ret.append((" ", None))
        if len(parts) == 2:
            ret.append((parts[0], parts[1]))
        else:
            ret.append((parts[0], None))

    logger.debug("RET %s", ret)
    return ret


if __name__ == "__main__":
    # Flat selector (equivalent to the nested form) for wider browser support.
    css = """
    #output .label {
        background-color: green !important;
    }
    """
    gr.Interface(
        correct,
        inputs=gr.Textbox(label="Input", placeholder="Enter text to be corrected here..."),
        outputs=gr.HighlightedText(
            label="Output",
            combine_adjacent=True,
            show_label=True,
            elem_id="output",
        ),
        theme=gr.themes.Default(),
        css=css,
    ).launch()