Upload app.py
app.py
CHANGED
@@ -0,0 +1,156 @@
# ---- Imports and configuration ----
# NOTE: the extracted diff does not show the import block or the MODEL_NAME_* /
# device definitions the code below relies on; they are reconstructed here so the
# file runs. The model names are assumed placeholders, not the author's choices.
import json

import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import pipeline

# Assumption: fill-mask models whose tokenizers use the <mask> token and the
# SentencePiece "▁" word-start marker (e.g. the XLM-R family), matching the code below.
MODEL_NAME_CHINESE = "xlm-roberta-base"  # placeholder
MODEL_NAME_ENGLISH = "xlm-roberta-base"  # placeholder
device = 0 if torch.cuda.is_available() else -1  # GPU if available, else CPU

WORD_PROBABILITY_THRESHOLD = 0.02
#WORD_PROBABILITY_THRESHOLD_ENGLISH = 0.02
#WORD_PROBABILITY_THRESHOLD_CHINESE = 0.02
TOP_K_WORDS = 10

ENGLISH_LANG = "English"
CHINESE_LANG = "Chinese"

CHINESE_WORDLIST = ['一定','一样','不得了','主观','从此','便于','俗话','倒霉','候选','充沛','分别','反倒','只好','同情','吹捧','咳嗽','围绕','如意','实行','将近','就职','应该','归还','当面','忘记','急忙','恢复','悲哀','感冒','成长','截至','打架','把握','报告','抱怨','担保','拒绝','拜访','拥护','拳头','拼搏','损坏','接待','握手','揭发','攀登','显示','普遍','未免','欣赏','正式','比如','流浪','涂抹','深刻','演绎','留念','瞻仰','确保','稍微','立刻','精心','结算','罕见','访问','请示','责怪','起初','转达','辅导','过瘾','运动','连忙','适合','遭受','重叠','镇静']

@st.cache_resource
def get_model_chinese():
    return pipeline("fill-mask", MODEL_NAME_CHINESE, device=device)

@st.cache_resource
def get_model_english():
    return pipeline("fill-mask", MODEL_NAME_ENGLISH, device=device)

@st.cache_data
def get_wordlist_chinese():
    return pd.read_csv('wordlist_chinese.csv')

@st.cache_data
def get_wordlist_english():
    return pd.read_csv('wordlist_english.csv')

def assess_chinese(word, sentence):
    print("Assessing Chinese")
    if sentence.lower().find(word.lower()) == -1:
        print('Sentence does not contain the word!')
        # raise, as assess_english does: a bare return would make the caller's
        # tuple unpacking fail with an opaque TypeError
        raise Exception("Sentence does not contain the target word")

    text = sentence.replace(word.lower(), "<mask>")

    top_k_prediction = mask_filler_chinese(text, top_k=TOP_K_WORDS)
    target_word_prediction = mask_filler_chinese(text, targets=word)

    score = target_word_prediction[0]['score']

    # append the original word if it's not found in the results
    top_k_prediction_filtered = [output for output in top_k_prediction
                                 if output['token_str'] == word]
    if len(top_k_prediction_filtered) == 0:
        top_k_prediction.extend(target_word_prediction)

    return top_k_prediction, score

def assess_english(word, sentence):
    if sentence.lower().find(word.lower()) == -1:
        raise Exception("Sentence does not contain the target word")

    text = sentence.replace(word.lower(), "<mask>")

    top_k_prediction = mask_filler_english(text, top_k=TOP_K_WORDS)
    # chr(9601) is "▁", the SentencePiece word-start marker; the target must
    # carry it to match the tokenizer's vocabulary entry for a whole word
    target_word_prediction = mask_filler_english(text, targets=chr(9601) + word)

    score = target_word_prediction[0]['score']

    # append the original word if it's not found in the results
    top_k_prediction_filtered = [output for output in top_k_prediction
                                 if output['token_str'] == word]
    if len(top_k_prediction_filtered) == 0:
        top_k_prediction.extend(target_word_prediction)

    return top_k_prediction, score

def assess_sentence(language, word, sentence):
    if language == ENGLISH_LANG:
        return assess_english(word, sentence)
    elif language == CHINESE_LANG:
        return assess_chinese(word, sentence)

def get_chinese_word():
    include = (wordlist_chinese.assess == True) & (wordlist_chinese.Chinese.apply(len) == 2)
    possible_words = wordlist_chinese[include]
    word = possible_words.sample(1).iloc[0].Chinese
    # the sampled word is overridden by the hard-coded test list for now
    test_words = CHINESE_WORDLIST
    word = np.random.choice(test_words)
    return word

def get_english_word():
    include = (wordlist_english.assess == True)
    possible_words = wordlist_english[include]
    word = possible_words.sample(1).iloc[0].word
    # as above, the sampled word is overridden by a hard-coded test list
    test_words = ["independent", "satisfied", "excited"]
    word = np.random.choice(test_words)
    return word

def get_word(language):
    if language == ENGLISH_LANG:
        return get_english_word()
    elif language == CHINESE_LANG:
        return get_chinese_word()

mask_filler_chinese = get_model_chinese()
mask_filler_english = get_model_english()
wordlist_chinese = get_wordlist_chinese()
wordlist_english = get_wordlist_english()

def highlight_given_word(row):
    color = '#ACE5EE' if row.Words == target_word else 'white'
    return [f'background-color:{color}'] * len(row)

def get_top_5_results(top_k_prediction):
    predictions_df = pd.DataFrame(top_k_prediction)
    predictions_df = predictions_df.drop(columns=["token", "sequence"])
    predictions_df = predictions_df.rename(columns={"score": "Probability", "token_str": "Words"})

    if (predictions_df[:5].Words == target_word).sum() == 0:
        # target word missed the top 5: append its row so it still shows in the table
        print("target word not in top 5")
        top_5_df = predictions_df[:5].copy()
        target_word_df = predictions_df[predictions_df.Words == target_word]
        print(target_word_df)
        top_5_df = pd.concat([top_5_df, target_word_df])
    else:
        top_5_df = predictions_df[:5].copy()

    # format probabilities for display (applies to both branches)
    top_5_df['Probability'] = top_5_df['Probability'].apply(lambda x: f"{x:.2%}")

    return top_5_df

#### Streamlit Page
st.title("造句 Auto-marking Demo")
language = st.radio("Select your language", (ENGLISH_LANG, CHINESE_LANG))
#st.info("You are practising on " + language)

if 'target_word' not in st.session_state:
    st.session_state['target_word'] = get_word(language)
target_word = st.session_state['target_word']

st.write("Target word: ", target_word)
if st.button("Get new word"):
    st.session_state['target_word'] = get_word(language)
    st.experimental_rerun()

st.subheader("Form your sentence and input below!")
sentence = st.text_input('Enter your sentence here', placeholder="Enter your sentence here!")

if st.button("Grade"):
    top_k_prediction, score = assess_sentence(language, target_word, sentence)
    # dump the raw predictions as valid JSON (str() would write a Python repr)
    with open('./result01.json', 'w') as outfile:
        json.dump(top_k_prediction, outfile, ensure_ascii=False)

    st.write(f"Probability: {score:.2%}")
    st.write(f"Target probability: {WORD_PROBABILITY_THRESHOLD:.2%}")
    predictions_df = get_top_5_results(top_k_prediction)
    df_style = predictions_df.style.apply(highlight_given_word, axis=1)

    if score >= WORD_PROBABILITY_THRESHOLD:
        st.balloons()
        st.success("Yay, good job! That's a great sentence 🕺 Practice again with another word", icon="✅")
        st.table(df_style)
    else:
        st.warning("Hmmm.. maybe try again?")
        st.table(df_style)
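
Note: the grading logic above keys off the dict structure that the transformers fill-mask pipeline returns. Below is a minimal sketch of that structure, under the same xlm-roberta-base assumption used for the placeholder model names; the example sentence and target word are illustrative only.

# Minimal sketch (not part of the upload) of the fill-mask output the app reads.
from transformers import pipeline

fill = pipeline("fill-mask", "xlm-roberta-base")

# Each prediction is a dict with 'score', 'token', 'token_str' and 'sequence' keys,
# which is why the app reads prediction[0]['score'] and filters on 'token_str'.
for p in fill("I am very <mask> with the result.", top_k=3):
    print(f"{p['token_str']}: {p['score']:.2%}")

# Restricting candidates to a single target yields that word's probability, which
# the app then compares against WORD_PROBABILITY_THRESHOLD.
target = fill("I am very <mask> with the result.", targets=chr(9601) + "happy")
print(f"happy: {target[0]['score']:.2%}")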