import streamlit as st
from streamlit_chat import message
import json
import torch
from torch.utils.data import Dataset
import torch.utils.data
from models import *
from utils import *
# Setting page title and header
st.set_page_config(page_title="UniLM", page_icon=":robot_face:")
# The title is emitted as raw HTML so it can be centred; that is why
# unsafe_allow_html=True is passed. The literal must stay on one line — a
# plain "..." string cannot span multiple source lines.
st.markdown("<h1 style='text-align: center;'>UniLM</h1>", unsafe_allow_html=True)
# Initialise session state variables. Streamlit re-executes this script on
# every interaction, so each key is only seeded when it is not present yet.
for _state_key, _state_default in {
    'generated': [],
    'past': [],
    'messages': [
        {"role": "system", "content": "You are a helpful assistant."}
    ],
    'model_name': [],
    'cost': [],
    'total_tokens': [],
    # Also created by the "Clear Conversation" reset; seed it here so the key
    # exists from the very first run.
    'number_tokens': [],
    # 0.0 (previously 1) so a fresh session starts from the same value the
    # "Clear Conversation" reset uses.
    'total_cost': 0.0,
}.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Sidebar: a model picker, an (empty) placeholder slot, and a button that
# clears the current conversation.
st.sidebar.title("Settings")
model_name = st.sidebar.selectbox("Model:", ("30M_6.1K","NONE"))
counter_placeholder = st.sidebar.empty()
clear_button = st.sidebar.button("Clear Conversation", key="clear")
# Resolve the selected option to a model label. Any option other than the
# local checkpoint maps to "gpt-4".
# NOTE(review): `model` is not read by the code below — looks like a leftover
# from a template; confirm before relying on it.
model = {"30M_6.1K": "30M_6.1K"}.get(model_name, "gpt-4")
# "Clear Conversation" pressed: wipe the chat history and every
# per-conversation counter, then restore the initial system prompt.
if clear_button:
    for history_key in ('generated', 'past', 'number_tokens',
                        'model_name', 'cost', 'total_tokens'):
        st.session_state[history_key] = []
    st.session_state['messages'] = [
        {"role": "system", "content": "You are a helpful assistant."}
    ]
    st.session_state['total_cost'] = 0.0
def evaluate(transformer, question, question_mask, max_len, word_map,
             *, start_key='<start>', end_key='<end>'):
    """
    Greedily decode a reply to one encoded question (batch size 1).

    Args:
        transformer: model exposing ``eval()``, ``encode(question, mask)``,
            ``decode(words, target_mask, encoded, question_mask)`` and
            ``logit(hidden)``.
        question: LongTensor of token ids; generate_response builds it with
            shape (1, question_len). Generated tokens are placed on the same
            device as this tensor.
        question_mask: mask forwarded unchanged to ``encode`` and ``decode``.
        max_len: length cap including the start token, so at most
            ``max_len - 1`` tokens are generated.
        word_map: mapping word -> token id.
        start_key, end_key: word_map keys of the start-of-sequence and
            end-of-sequence tokens.
            NOTE(review): the defaults assume the word map uses '<start>' /
            '<end>' — confirm against WORDMAP_corpus.json.

    Returns:
        The generated words joined by single spaces. The start token is
        dropped, and decoding stops before the end token would be emitted.
    """
    rev_word_map = {idx: word for word, idx in word_map.items()}
    start_token = word_map[start_key]
    end_token = word_map[end_key]
    dev = question.device

    transformer.eval()
    # Pure inference: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        encoded = transformer.encode(question, question_mask)
        words = torch.tensor([[start_token]], dtype=torch.long, device=dev)
        for _ in range(max_len - 1):
            size = words.shape[1]
            # Causal mask: position i may attend only to positions <= i.
            target_mask = torch.tril(torch.ones(size, size, device=dev)).to(torch.uint8)
            target_mask = target_mask.unsqueeze(0).unsqueeze(0)  # (1, 1, size, size)
            decoded = transformer.decode(words, target_mask, encoded, question_mask)
            predictions = transformer.logit(decoded[:, -1])
            _, next_word = torch.max(predictions, dim=1)
            next_word = next_word.item()
            if next_word == end_token:
                break
            next_tensor = torch.tensor([[next_word]], dtype=torch.long, device=dev)
            words = torch.cat([words, next_tensor], dim=1)  # (1, step + 2)

    token_ids = words.squeeze(0).tolist()
    return ' '.join(rev_word_map[idx] for idx in token_ids if idx != start_token)
# Characters deleted by remove_punc. The backslash is written as '\\' because
# a bare '\,' is an invalid escape sequence (SyntaxWarning on Python 3.12+).
_PUNCTUATION = '''!()-[]{};:'"\\,<>./?@#$%^&*_~'''
# Translation table that deletes every character in _PUNCTUATION; built once.
_PUNCTUATION_TABLE = str.maketrans('', '', _PUNCTUATION)


def remove_punc(string):
    """
    Return *string* lower-cased, with every character in _PUNCTUATION removed.

    Whitespace is kept, so the result can be split into words. Characters
    outside the set (e.g. '+', '=', '|') are left untouched.
    """
    return string.translate(_PUNCTUATION_TABLE).lower()
# Checkpoint and word-map files per selectable model. Options without an entry
# (currently "NONE") fall back to the 30M_6.1K files, as before.
_MODEL_FILES = {
    "30M_6.1K": ('checkpoint_190.pth.tar', 'WORDMAP_corpus.json'),
}
_DEFAULT_MODEL_FILES = _MODEL_FILES["30M_6.1K"]

# Streamlit re-runs this script on every interaction. st.cache_resource
# (Streamlit >= 1.18) keeps the loaded model across reruns; on older Streamlit
# versions the loader runs uncached.
_cache_model = getattr(st, "cache_resource", lambda func: func)


@_cache_model
def _load_model(ckpt_path, word_map_path):
    """Load the JSON word map and the checkpoint dict (on CPU) from disk."""
    with open(word_map_path, 'r', encoding='utf-8') as j:
        loaded_word_map = json.load(j)
    # The checkpoint stores a whole pickled module under 'transformer', which
    # torch >= 2.6 refuses to unpickle by default (weights_only=True). Only
    # load trusted, local checkpoints with weights_only=False.
    try:
        loaded_checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'),
                                       weights_only=False)
    except TypeError:  # torch < 1.13 has no weights_only parameter
        loaded_checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
    return loaded_word_map, loaded_checkpoint


ckpt_path, _word_map_path = _MODEL_FILES.get(model_name, _DEFAULT_MODEL_FILES)
word_map, checkpoint = _load_model(ckpt_path, _word_map_path)
transformer = checkpoint['transformer']
# generate a response
def generate_response(prompt, max_len=153):
    """
    Answer *prompt* with the loaded transformer and log both turns.

    The prompt is lower-cased and stripped of punctuation (remove_punc), split
    on whitespace and mapped to token ids. Unknown words map to the '<unk>' id.
    NOTE(review): '<unk>' is assumed to be the word map's unknown-token key —
    confirm against WORDMAP_corpus.json.

    Side effects: appends the user prompt and the reply to
    st.session_state['messages'].

    Args:
        prompt: raw user text.
        max_len: decoding length cap passed to evaluate (default 153).

    Returns:
        (response, total_tokens, prompt_tokens, completion_tokens). The counts
        are whitespace-separated word counts, returned as strings.
    """
    st.session_state['messages'].append({"role": "user", "content": prompt})
    question_words = remove_punc(prompt).split()
    unk_token = word_map['<unk>']
    # Look the unk id up once. An all-punctuation prompt would otherwise give
    # a zero-length sequence, so feed a single unknown token instead.
    enc_qus = [word_map.get(word, unk_token) for word in question_words] or [unk_token]
    question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)
    # The mask is computed from `question`, so it is already on `device`;
    # id 0 is treated as padding.
    question_mask = (question != 0).unsqueeze(1).unsqueeze(1)
    response = evaluate(transformer, question, question_mask, int(max_len), word_map)
    st.session_state['messages'].append({"role": "assistant", "content": response})
    prompt_tokens = len(question_words)
    completion_tokens = len(response.split())
    total_tokens = prompt_tokens + completion_tokens
    return response, str(total_tokens), str(prompt_tokens), str(completion_tokens)
# container for chat history (filled in after the form is processed)
response_container = st.container()
# container for text box
container = st.container()
with container:
    with st.form(key='my_form', clear_on_submit=True):
        # Recent Streamlit releases reject st.text_area heights below 68px,
        # so the previous height=2 raised StreamlitAPIException.
        user_input = st.text_area("You:", key='input', height=100)
        submit_button = st.form_submit_button(label='✉')
    if submit_button and user_input:
        output, total_tokens, prompt_tokens, completion_tokens = generate_response(user_input)
        st.session_state['past'].append(user_input)
        st.session_state['generated'].append(output)
        st.session_state['model_name'].append(model_name)
        st.session_state['total_tokens'].append(total_tokens)
        # Placeholder per-model cost label. It is not recorded in
        # st.session_state['cost'] or 'total_cost'.
        cost = "1" if model_name == "30M_6.1K" else "2"
# Replay the conversation into the history container: every user turn is
# followed by the corresponding model reply. Widget keys are "<i>_user" / "<i>".
if st.session_state['generated']:
    with response_container:
        for turn, reply in enumerate(st.session_state["generated"]):
            message(st.session_state["past"][turn], is_user=True, key=f"{turn}_user")
            message(reply, key=f"{turn}")