File size: 3,685 Bytes
7028ae7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30b774c
 
7028ae7
215cbb5
804f76a
 
 
215cbb5
7028ae7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7373fb4
7028ae7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import streamlit as st
from multiprocessing import Process
import json
import requests
import time
import os


def start_server():
    '''Run the inference service under uvicorn (blocking).

    Executes ``uvicorn InferenceServer:app`` on port 8080, bound to all
    interfaces, with two workers. ``os.system`` only returns once that command
    exits, so this is meant to run in its own process (load_models uses it as a
    ``multiprocessing.Process`` target). The exit status is discarded.
    '''
    launch_command = "uvicorn InferenceServer:app --port 8080 --host 0.0.0.0 --workers 2"
    os.system(launch_command)

def load_models(startup_timeout=None):
    '''One-time start-up of the model server as a separate process.

    If nothing accepts connections on port 8080, launches ``start_server`` in a
    daemon ``multiprocessing.Process`` and polls the port once per second until
    it opens. On success (or if the server was already running) sets
    ``st.session_state['models_loaded']`` to True.

    Args:
        startup_timeout: Maximum number of seconds to wait for the port to open.
            ``None`` (the default) waits indefinitely, as before.

    On failure (the server process exits before listening, or the timeout
    elapses) an error is shown and the flag is left unchanged, so the next
    script run tries again.
    '''
    if is_port_in_use(8080):
        st.success("Model server already running...")
        st.session_state['models_loaded'] = True
        return

    with st.spinner(text="Loading model, please wait..."):
        proc = Process(target=start_server, args=(), daemon=True)
        proc.start()
        deadline = None if startup_timeout is None else time.monotonic() + startup_timeout
        while not is_port_in_use(8080):
            # start_server returns -- ending the child process -- as soon as the
            # uvicorn command exits, so a dead child means the port will never
            # open. Stop polling instead of spinning forever.
            if not proc.is_alive():
                st.error("Model server process exited before it started listening on port 8080.")
                return
            if deadline is not None and time.monotonic() >= deadline:
                st.error(f"Model server did not start within {startup_timeout} seconds.")
                return
            time.sleep(1)
        st.success("Model server started.")
    st.session_state['models_loaded'] = True

def is_port_in_use(port, host='127.0.0.1'):
    '''Return True if a TCP connection to ``host:port`` succeeds.

    Used to detect whether the inference service is already listening.

    Args:
        port: TCP port to probe.
        host: Address to connect to. Defaults to loopback. The previous target,
            '0.0.0.0', is a bind-only wildcard that some platforms (e.g.
            Windows) reject as a connect destination. The server binds all
            interfaces, so loopback reaches it.
    '''
    import socket
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Bound the probe so an unresponsive host cannot stall a polling loop.
        s.settimeout(1.0)
        return s.connect_ex((host, port)) == 0

# Seed the per-session "server ready" flag only when this session has not set it
# yet; an existing value (e.g. True after load_models) is left alone.
if 'models_loaded' not in st.session_state:
    st.session_state.models_loaded = False

def get_correction(input_text):
    '''Query the inference service for ``input_text`` and render the result.

    Sends the sentence to the local ``/restore`` endpoint on port 8080 and reads
    ``corrected_sentence`` from the JSON reply. It then renders that sentence
    with changed words highlighted by ``diff_strings``. Connection errors, HTTP
    error statuses and malformed replies are reported with ``st.error`` instead
    of crashing the page.
    '''
    st.markdown('##### Corrected text:')
    st.write('')
    with st.spinner('Wait for it...'):
        try:
            # Pass the sentence via ``params`` so requests URL-encodes it;
            # concatenating it into the URL broke on '&', '#', '+', '?', etc.
            # Connect to loopback: the server binds 0.0.0.0 (all interfaces),
            # but the wildcard address is not a valid connect target everywhere.
            correct_response = requests.get(
                "http://127.0.0.1:8080/restore",
                params={"input_sentence": input_text},
            )
            correct_response.raise_for_status()
            correct_json = json.loads(correct_response.text)
            corrected_sentence = correct_json["corrected_sentence"]
        except (requests.RequestException, ValueError, KeyError, TypeError) as err:
            st.error(f"Correction request failed: {err}")
            return
        result = diff_strings(corrected_sentence, input_text)
    st.markdown(result, unsafe_allow_html=True)

def diff_strings(output_text, input_text):
    '''Return ``output_text`` as HTML with words not found in ``input_text`` highlighted.

    Both strings are split on single spaces. Each word of ``output_text`` that
    does not appear verbatim anywhere in ``input_text`` is wrapped in a bold
    green ``<span>``; other words are kept as-is. The comparison is case- and
    punctuation-sensitive, so re-cased or re-punctuated words are highlighted.
    Every word is HTML-escaped, because the result is rendered as raw HTML
    (``unsafe_allow_html=True``). Each word is followed by one space, so the
    result ends with a trailing space.
    '''
    import html

    # Build the lookup set once instead of re-splitting input_text per word.
    input_words = set(input_text.split(" "))
    pieces = []
    for word in output_text.split(" "):
        # Compare raw words; escape only for display.
        shown = html.escape(word, quote=False)
        if word in input_words:
            pieces.append(shown)
        else:
            pieces.append('<span style="font-weight:bold; color:rgb(142, 208, 129);">' + shown + '</span>')
    return "".join(piece + " " for piece in pieces)
        
if __name__ == "__main__":
    # Sample inputs for the drop-down: lower-case, unpunctuated sentences.
    example_sentences = [
        "my name is clara and i live in berkeley california",
        "in 2018 cornell researchers built a high-power detector",
        "lorem ipsum has been the industrys standard dummy text ever since the 1500s when an unknown printer took a galley of type and scrambled it to make a type specimen book"
    ]

    # ---- Page header and usage snippet ----
    st.title('Rpunct')
    st.subheader('For Punctuation and Upper Case restoration')
    st.markdown("Spaces for [felflare/bert-restore-punctuation](https://huggingface.co/felflare/bert-restore-punctuation) using [Fork with CPU support](https://github.com/anuragshas/rpunct) | [Original repo](https://github.com/Felflare/rpunct)", unsafe_allow_html=True)
    st.markdown("Model restores the following punctuations -- [! ? . , - : ; ' ] and also the upper-casing of words.")
    st.markdown("Integrate with just few lines of code", unsafe_allow_html=True)
    st.markdown("""
                    ```python 
                    from rpunct import RestorePuncts
                    rpunct = RestorePuncts()
                    rpunct.punctuate('''my name is clara and i live in berkeley california''')
                    ```    
                    """)

    # Start (or detect) the inference server unless this session already did.
    if not st.session_state['models_loaded']:
        load_models()

    # The text box is pre-filled with the chosen example, so the user can
    # either pick a sample or type a sentence of their own.
    chosen_example = st.selectbox(label="Choose an example", options=example_sentences)
    st.write("(or)")
    sentence = st.text_input(label="Input sentence", value=chosen_example)

    # Only call the service for non-blank input.
    if sentence.strip():
        get_correction(sentence)