File size: 2,159 Bytes
18cb05d
 
 
d1eaeb1
37361c0
 
214f1cd
 
d1eaeb1
18cb05d
 
 
 
 
d1eaeb1
18cb05d
d1eaeb1
18cb05d
 
 
 
81fdd06
18cb05d
 
 
 
 
 
 
 
 
 
 
 
d1eaeb1
18cb05d
 
 
 
 
 
 
 
 
 
2d6b756
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Import the required libraries
import streamlit as st
from transformers import pipeline
# Load the fill-mask model once and share it across Streamlit reruns.
#
# Streamlit re-executes this whole script on every user interaction (for
# example, each click on the Submit button), so a bare module-level
# pipeline(...) call would rebuild the model every single time.
# st.cache_resource keeps one pipeline object alive for the lifetime of the
# server process and hands the same object back on every rerun.
@st.cache_resource
def _load_model():
    """Build the Hugging Face fill-mask pipeline used by the app."""
    return pipeline(task="fill-mask",
                    model="MUmairAB/bert-based-MaskedLM")


# The loader is still called here, at script start, rather than lazily inside
# the main function: the slow first download then happens while the app is
# starting up, instead of after the user has already submitted their text.
model = _load_model()



def print_the_mask(text):
    """Fill the ``[MASK]`` token(s) in *text* and display the candidates.

    The text (for example ``"How are [MASK]"``) is passed to the module-level
    fill-mask ``model``. Every candidate sentence is shown with
    ``st.success``, ordered from the highest to the lowest score.

    Args:
        text: The sentence containing one or more ``[MASK]`` tokens.

    Returns:
        None. The results are written to the Streamlit page.
    """
    # Apply the model
    model_out = model(text)

    # With a single [MASK] the pipeline returns a flat list of dictionaries;
    # with several masks it returns one such list per mask. Wrap the flat
    # form so that both cases are handled by the same loop below (sorting
    # the nested form directly would fail on x['score']).
    if model_out and isinstance(model_out[0], dict):
        model_out = [model_out]

    several_masks = len(model_out) > 1
    for position, candidates in enumerate(model_out, start=1):
        if several_masks:
            # Tell the user which mask the following suggestions belong to.
            st.write(f"Predictions for [MASK] number {position}:")

        # Show the most likely completion first.
        ranked = sorted(candidates, key=lambda x: x["score"], reverse=True)
        for sub_dict in ranked:
            st.success(sub_dict["sequence"])


def main():
    """Render the Streamlit page and handle a submitted sentence.

    Draws the title, author credit, description and usage instructions,
    then a text box and a Submit button. When the button is pressed the text
    is checked for a "[MASK]" token: if present it is passed to
    print_the_mask, otherwise a reminder message is written to the page.
    """
    # Page heading and author credit
    st.title("Masked Language Model App")
    st.write("Created by: [Umair Akram](https://www.linkedin.com/in/m-umair01/)")

    # Short description of what the app does
    st.subheader(
        "This App uses a fine-tuned DistilBERT-Base-Uncased Masked Language Model to predict the missed word in a sentence."
    )

    # Pointer to the source code, followed by the usage instructions
    st.write("Its code and other interesting projects are available on my [website](https://mumairab.github.io/)")
    st.write(
        "Enter your text and put \"[MASK]\" at the word which you want to predict, as shown in the following example: Can we [MASK] to Paris?"
    )

    text = st.text_input(
        label="Enter your text here:",
        value="Type here ...",
    )

    # Nothing more to do until the user presses the button
    submitted = st.button('Submit')
    if not submitted:
        return

    # Input validation: the model needs a [MASK] token to fill in
    if "[MASK]" in text:
        print_the_mask(text)
    else:
        st.write("You did not enter \"[MASK]\" in the text. Please write your text again!")

#Call the main function only when this file is run as a script
if __name__ == "__main__":
    #Build the Streamlit page (this app uses Streamlit, not Gradio)
    main()