Spaces:
Runtime error
Runtime error
File size: 7,144 Bytes
97c311c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import streamlit as st
from huggingface_hub import snapshot_download
import os # utility library
# libraries to load the model and serve inference
import tensorflow_text
import tensorflow as tf
def main():
    """Entry point of the app: draw the page header, load the model and show the dashboard.

    The loaded model is kept in ``st.session_state`` and is only loaded from
    disk when the ``model`` key is not there yet.
    """
    st.title("Interactive demo: T5 Multitasking Demo")
    st.sidebar.image("https://i.gzn.jp/img/2020/02/25/google-ai-t5/01.png")

    model_dir = load_model_cache()

    # Reuse the model stored in the session state when one is already present
    state = st.session_state
    if 'model' not in state:
        state.model = tf.saved_model.load(model_dir, ["serve"])

    dashboard(state.model)
@st.cache
def load_model_cache():
    """Function to retrieve the model from HuggingFace Hub and cache it using st.cache wrapper

    returns:
        path of the local folder holding the downloaded repository snapshot,
        as returned by snapshot_download
    """
    CACHE_DIR = "hfhub_cache"  # where the library's fork would be stored once downloaded

    # exist_ok=True replaces the separate os.path.exists check followed by os.mkdir
    os.makedirs(CACHE_DIR, exist_ok=True)

    # download the files from huggingface repo. The path returned by
    # snapshot_download is used directly instead of os.listdir(CACHE_DIR)[0]:
    # os.listdir gives entries in arbitrary order, and the cache directory may
    # contain entries other than the snapshot folder itself.
    saved_model_path = snapshot_download(repo_id="stevekola/T5", cache_dir=CACHE_DIR)
    return saved_model_path
def dashboard(model):
    """Function to display the inputs and results

    Shows the task selector in the sidebar, the upload widgets and the input
    text area(s) for the selected task, then runs the prediction and displays
    the output together with a download button.

    params:
        model       model to run inference from (passed on to predict)
    """
    task_type = st.sidebar.radio(
        "Task Type",
        [
            "Translate English to French",
            "Translate English to German",
            "Translate English to Romanian",
            "Grammatical Correctness of Sentence",
            "Text Summarization",
            "Document Similarity Score",
        ],
    )

    # Sample texts shown in the input boxes until the user types or uploads something
    default_sentence = "I am Steven and I live in Lagos, Nigeria."
    text_summarization_sentence = (
        "I don't care about those doing the comparison, but comparing "
        "the Ghanaian Jollof Rice to Nigerian Jollof Rice is an insult to Nigerians."
    )
    doc_similarity_sentence1 = "I reside in the commercial capital city of Nigeria, which is Lagos."
    doc_similarity_sentence2 = "I live in Lagos."
    help_msg = (
        "You could either type in the sentences to run inferences on or use the upload button to "
        "upload text files containing those sentences. The input sentence box, by default, displays sample "
        "texts or the texts in the files that you've uploaded. Feel free to erase them and type in new sentences."
    )

    if task_type.startswith("Document Similarity"):  # document similarity requires two documents
        uploaded_file = upload_files(
            help_msg,
            text="Upload 2 documents for similarity check",
            accept_multiple_files=True,
        )
        first_text, second_text = doc_similarity_sentence1, doc_similarity_sentence2
        if uploaded_file:
            if len(uploaded_file) >= 2:
                first_text, second_text = uploaded_file[0], uploaded_file[1]
            else:
                # A single uploaded file used to raise IndexError on uploaded_file[1];
                # keep the sample texts and tell the user what is missing instead.
                st.write("**Please upload two documents for the similarity check.**")
        sentence1 = st.text_area("Enter first document/sentence", first_text, help=help_msg)
        sentence2 = st.text_area("Enter second document/sentence", second_text, help=help_msg)
        sentence = sentence1 + "---" + sentence2  # to be processed like other tasks' single sentences
    else:
        uploaded_file = upload_files(help_msg)
        if uploaded_file:
            initial_text = uploaded_file
        elif task_type.startswith("Text Summarization"):  # text summarization's default input should be longer
            initial_text = text_summarization_sentence
        else:
            initial_text = default_sentence
        sentence = st.text_area("Enter sentence", initial_text, help=help_msg)

    st.write("**Output Text**")
    with st.spinner("Waiting for prediction..."):  # spinner while model is running inferences
        output_text = predict(task_type, sentence, model)
    st.write(output_text)

    try:  # to workaround the environment's Streamlit version
        st.download_button("Download output text", output_text)
    except AttributeError:
        st.text("File download not enabled for this Streamlit version \U0001F612")
def upload_files(help_msg, text="Upload a text file here", accept_multiple_files=False):
    """Function to upload text files and return as string text

    params:
        help_msg                help text attached to the file uploader
        text                    Display label for the expander wrapping the upload widgets
        accept_multiple_files   params for the file_uploader function to accept more than a file
    returns:
        a string or a list of strings (in case of multiple files being uploaded)
        once the "Process" button is clicked with file(s) present; None otherwise
    """
    def upload():
        # Draw the uploader and the "Process" button; decode the files on click
        uploaded_files = st.file_uploader(label="Upload text files only",
                                          type="txt", help=help_msg,
                                          accept_multiple_files=accept_multiple_files)
        if not st.button("Process"):
            return None
        if not uploaded_files:
            st.write("**No file uploaded!**")
            return None
        st.write("**Upload successful!**")
        if isinstance(uploaded_files, list):
            return [f.read().decode("utf-8") for f in uploaded_files]
        return uploaded_files.read().decode("utf-8")

    # Only the attribute lookup is guarded. Wrapping the whole upload() call in
    # the try block would catch any AttributeError raised inside it and then
    # run upload() a second time, drawing the widgets twice.
    try:  # to workaround the environment's Streamlit version
        expander = st.expander
    except AttributeError:
        return upload()
    with expander(text):
        return upload()
def predict(task_type, sentence, model):
    """Function to parse the user inputs, run the parsed text through the
    model and return output in a readable format.

    params:
        task_type   sentence representing the type of task to run on T5 model
        sentence    sentence to get inference on; for the document similarity
                    task, the two documents joined by '---'
        model       model to get inferences from
    returns:
        text decoded into a human-readable format.
    raises:
        KeyError if task_type is not one of the known task names
    """
    # NOTE(review): the task prefixes published with T5 are lowercase
    # ("translate English to French") - confirm the capitalised variants below
    # are what this checkpoint expects.
    task_dict = {
        "Translate English to French": "Translate English to French",
        "Translate English to German": "Translate English to German",
        "Translate English to Romanian": "Translate English to Romanian",
        "Grammatical Correctness of Sentence": "cola sentence",
        "Text Summarization": "summarize",
        "Document Similarity Score": "stsb",
    }
    prefix = task_dict[task_type]

    # parsing the user inputs into a format recognized by T5
    if task_type.startswith("Document Similarity"):
        # Document Similarity takes in two sentences so it has to be parsed in a
        # separate manner. partition splits on the first '---' only, so a '---'
        # inside the second document no longer truncates it, and a missing
        # separator yields an empty second sentence instead of an IndexError.
        sentence1, _, sentence2 = sentence.partition('---')
        question = f"{prefix} sentence1: {sentence1} sentence2: {sentence2}"
    else:
        question = f"{prefix}: {sentence}"

    return predict_fn([question], model)[0].decode('utf-8')
def predict_fn(x, model):
    """Run the model's 'serving_default' signature on the given inputs.

    params:
        x       input text to get output on
        model   model to run inferences from
    returns:
        the 'outputs' entry of the signature's result, converted with .numpy()
    """
    serving_fn = model.signatures['serving_default']
    inputs = tf.constant(x)
    result = serving_fn(inputs)
    return result['outputs'].numpy()
if __name__ == "__main__":
    # Start the app only when this file is executed as the main module
    main()
|