Spaces:
Runtime error
Runtime error
inclusive-ml
committed on
Commit
•
97c311c
1
Parent(s):
4ee896d
initial commit
Browse files- app.py +132 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from huggingface_hub import snapshot_download
|
3 |
+
import os # utility library
|
4 |
+
# libraries to load the model and serve inference
|
5 |
+
import tensorflow_text
|
6 |
+
import tensorflow as tf
|
7 |
+
def main():
    """Entry point: set up the page chrome, load the model once, and render the dashboard."""
    st.title("Interactive demo: T5 Multitasking Demo")
    st.sidebar.image("https://i.gzn.jp/img/2020/02/25/google-ai-t5/01.png")
    model_dir = load_model_cache()
    # The loaded SavedModel lives in st.session_state so Streamlit reruns
    # (triggered on every widget interaction) do not reload it from disk.
    if 'model' not in st.session_state:
        st.session_state.model = tf.saved_model.load(model_dir, ["serve"])
    dashboard(st.session_state.model)
|
15 |
+
@st.cache
def load_model_cache():
    """Download the T5 SavedModel from the HuggingFace Hub and return its local path.

    Wrapped in st.cache so the (slow) snapshot download runs only once per
    Streamlit session instead of on every rerun.

    returns:
        path to the snapshot directory created by snapshot_download inside
        the local cache directory.
    """
    CACHE_DIR = "hfhub_cache"  # where the library's fork would be stored once downloaded
    # makedirs with exist_ok avoids the check-then-create race of the
    # original os.path.exists(...) + os.mkdir(...) pair.
    os.makedirs(CACHE_DIR, exist_ok=True)
    # download the files from huggingface repo and load the model with tensorflow
    snapshot_download(repo_id="stevekola/T5", cache_dir=CACHE_DIR)
    # assumes the snapshot is the only entry inside CACHE_DIR -- TODO confirm
    # this holds when the cache dir is shared with other downloads
    saved_model_path = os.path.join(CACHE_DIR, os.listdir(CACHE_DIR)[0])
    return saved_model_path
|
26 |
+
def dashboard(model):
    """Render the task selector, the input widgets, and the model's output.

    params:
        model   stateless model to run inference from
    """
    task_type = st.sidebar.radio("Task Type",
                                 [
                                     "Translate English to French",
                                     "Translate English to German",
                                     "Translate English to Romanian",
                                     "Grammatical Correctness of Sentence",
                                     "Text Summarization",
                                     "Document Similarity Score"
                                 ])
    default_sentence = "I am Steven and I live in Lagos, Nigeria."
    text_summarization_sentence = "I don't care about those doing the comparison, but comparing \
the Ghanaian Jollof Rice to Nigerian Jollof Rice is an insult to Nigerians."
    doc_similarity_sentence1 = "I reside in the commercial capital city of Nigeria, which is Lagos."
    doc_similarity_sentence2 = "I live in Lagos."
    help_msg = "You could either type in the sentences to run inferences on or use the upload button to \
upload text files containing those sentences. The input sentence box, by default, displays sample \
texts or the texts in the files that you've uploaded. Feel free to erase them and type in new sentences."
    if task_type.startswith("Document Similarity"):  # document similarity requires two documents
        uploaded_file = upload_files(help_msg, text="Upload 2 documents for similarity check", accept_multiple_files=True)
        # BUGFIX: require at least two uploaded files before indexing; the
        # original indexed uploaded_file[1] unconditionally and raised
        # IndexError when the user uploaded only one document.
        if uploaded_file and len(uploaded_file) >= 2:
            sentence1 = st.text_area("Enter first document/sentence", uploaded_file[0], help=help_msg)
            sentence2 = st.text_area("Enter second document/sentence", uploaded_file[1], help=help_msg)
        else:
            sentence1 = st.text_area("Enter first document/sentence", doc_similarity_sentence1)
            sentence2 = st.text_area("Enter second document/sentence", doc_similarity_sentence2)
        sentence = sentence1 + "---" + sentence2  # to be processed like other tasks' single sentences
    else:
        uploaded_file = upload_files(help_msg)
        if uploaded_file:
            sentence = st.text_area("Enter sentence", uploaded_file, help=help_msg)
        elif task_type.startswith("Text Summarization"):  # text summarization's default input should be longer
            sentence = st.text_area("Enter sentence", text_summarization_sentence, help=help_msg)
        else:
            sentence = st.text_area("Enter sentence", default_sentence, help=help_msg)
    st.write("**Output Text**")
    with st.spinner("Waiting for prediction..."):  # spinner while model is running inferences
        output_text = predict(task_type, sentence, model)
    st.write(output_text)
    try:  # to workaround the environment's Streamlit version
        st.download_button("Download output text", output_text)
    except AttributeError:
        st.text("File download not enabled for this Streamlit version \U0001F612")
|
73 |
+
def upload_files(help_msg, text="Upload a text file here", accept_multiple_files=False):
    """Upload text files and return their contents as decoded strings.

    params:
        help_msg               tooltip text shown on the uploader widget
        text                   display label for the upload expander/button
        accept_multiple_files  param for the file_uploader function to accept more than a file
    returns:
        a string, a list of strings (when multiple files are uploaded), or
        None when nothing was uploaded / the Process button was not pressed.
    """
    def upload():
        # Inner helper so the same widget logic can run both inside an
        # expander (newer Streamlit) and bare (older Streamlit).
        uploaded_files = st.file_uploader(label="Upload text files only",
                                          type="txt", help=help_msg,
                                          accept_multiple_files=accept_multiple_files)
        if st.button("Process"):
            if not uploaded_files:
                st.write("**No file uploaded!**")
                return None
            st.write("**Upload successful!**")
            # isinstance is the idiomatic type check (was: type(x) == list);
            # file_uploader returns a list only when accept_multiple_files=True.
            if isinstance(uploaded_files, list):
                return [f.read().decode("utf-8") for f in uploaded_files]
            return uploaded_files.read().decode("utf-8")
    try:  # to workaround the environment's Streamlit version
        with st.expander(text):
            return upload()
    except AttributeError:  # st.expander missing on older Streamlit
        return upload()
|
98 |
+
def predict(task_type, sentence, model):
    """Build a T5-formatted prompt from the user's input, run the model, and decode.

    params:
        task_type  label of the selected task (one of the sidebar radio options)
        sentence   input text; for similarity, two documents joined by '---'
        model      model to get inferences from
    returns:
        the model output decoded into a human-readable string.
    """
    # Maps the UI task label to the prefix T5 expects in its prompt.
    prefix_by_task = {
        "Translate English to French": "Translate English to French",
        "Translate English to German": "Translate English to German",
        "Translate English to Romanian": "Translate English to Romanian",
        "Grammatical Correctness of Sentence": "cola sentence",
        "Text Summarization": "summarize",
        "Document Similarity Score": "stsb",
    }
    prefix = prefix_by_task[task_type]
    if task_type.startswith("Document Similarity"):
        # Similarity carries two sentences in one '---'-joined string, and
        # T5 wants them tagged as sentence1/sentence2 instead of prefixed.
        parts = sentence.split('---')
        question = f"{prefix} sentence1: {parts[0]} sentence2: {parts[1]}"
    else:
        question = f"{prefix}: {sentence}"
    return predict_fn([question], model)[0].decode('utf-8')
|
122 |
+
def predict_fn(x, model):
    """Run the model's serving signature on a batch of input strings.

    params:
        x      list of prompt strings to run inference on
        model  model to run inferences from
    returns:
        a numpy array representing the output
    """
    serving_fn = model.signatures['serving_default']
    outputs = serving_fn(tf.constant(x))['outputs']
    return outputs.numpy()
|
131 |
+
# Run the Streamlit app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
t5
|
2 |
+
huggingface_hub
|
3 |
+
streamlit==1.0.0
|