"""Streamlit chat front-end for a locally trained transformer ("UniLM")."""
import json

import streamlit as st
from streamlit_chat import message
import torch
from torch.utils.data import Dataset
import torch.utils.data

# Star imports kept as in the original. They presumably provide `device` and the
# model classes needed to unpickle the checkpoint - verify against models/utils.
from models import *
from utils import *

# Checkpoint and vocabulary used by every selectable model option.
CKPT_PATH = 'checkpoint_190.pth.tar'
WORDMAP_PATH = 'WORDMAP_corpus.json'
# Maximum decoded length, start token included.
MAX_LEN = 153
SYSTEM_MESSAGE = {"role": "system", "content": "You are a helpful assistant."}
# Characters stripped from user input before tokenisation. A raw string makes
# the backslash an explicit member instead of relying on the invalid "\," escape.
PUNCTUATIONS = r'''!()-[]{};:'"\,<>./?@#$%^&*_~'''
_PUNCT_TABLE = str.maketrans('', '', PUNCTUATIONS)

# Setting page title and header
st.set_page_config(page_title="UniLM", page_icon=":robot_face:")
# NOTE(review): the HTML in this string was lost from the pasted source (a bare
# newline inside "..." is a SyntaxError); restored as a centred heading.
st.markdown("<h1 style='text-align: center;'>UniLM</h1>", unsafe_allow_html=True)


def _default_session_state():
    """Return fresh default values for every per-conversation session key."""
    return {
        'generated': [],
        'past': [],
        'messages': [dict(SYSTEM_MESSAGE)],
        'number_tokens': [],
        'model_name': [],
        'cost': [],
        'total_cost': 0.0,
        'total_tokens': [],
    }


def _init_session_state():
    """Create any missing session keys, leaving existing values untouched."""
    for key, value in _default_session_state().items():
        if key not in st.session_state:
            st.session_state[key] = value


def _reset_conversation():
    """Overwrite every per-conversation session key with its default value."""
    for key, value in _default_session_state().items():
        st.session_state[key] = value


_init_session_state()

# Sidebar: model choice and a button to clear the current conversation.
st.sidebar.title("Settings")
model_name = st.sidebar.selectbox("Model:", ("30M_6.1K", "NONE"))
counter_placeholder = st.sidebar.empty()
clear_button = st.sidebar.button("Clear Conversation", key="clear")

if clear_button:
    _reset_conversation()


def evaluate(transformer, question, question_mask, max_len, word_map):
    """Greedily decode a reply to one question (batch size 1).

    Args:
        transformer: model exposing ``encode``, ``decode`` and ``logit``.
        question: encoded question tensor, expected shape (1, seq_len), on ``device``.
        question_mask: padding mask for ``question`` (as built in generate_response).
        max_len: maximum decoded length, start token included.
        word_map: vocabulary mapping token strings to ids.

    Returns:
        The decoded words joined by single spaces, start/end tokens excluded.
    """
    rev_word_map = {idx: tok for tok, idx in word_map.items()}
    # NOTE(review): special-token names were stripped from the pasted source and
    # restored here; confirm they match the keys in WORDMAP_corpus.json.
    start_token = word_map['<start>']
    end_token = word_map['<end>']
    transformer.eval()
    with torch.no_grad():  # inference only: no autograd graph is needed
        encoded = transformer.encode(question, question_mask)
        words = torch.LongTensor([[start_token]]).to(device)
        for _ in range(max_len - 1):
            size = words.shape[1]
            # Lower-triangular mask: position i may attend only to positions <= i.
            target_mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
            target_mask = target_mask.to(device).unsqueeze(0).unsqueeze(0)
            decoded = transformer.decode(words, target_mask, encoded, question_mask)
            predictions = transformer.logit(decoded[:, -1])
            _, next_word = torch.max(predictions, dim=1)
            next_word = next_word.item()
            if next_word == end_token:
                break
            words = torch.cat([words, torch.LongTensor([[next_word]]).to(device)], dim=1)  # (1, step+2)

    token_ids = words.squeeze(0).tolist()
    return ' '.join(rev_word_map[w] for w in token_ids if w != start_token)


def remove_punc(string):
    """Return ``string`` lower-cased with every character in PUNCTUATIONS removed.

    Whitespace is preserved so the result can still be split into words.
    """
    return string.translate(_PUNCT_TABLE).lower()


def _load_model(ckpt_path=CKPT_PATH, wordmap_path=WORDMAP_PATH):
    """Load the vocabulary and the pickled transformer, moved to ``device``.

    Returns:
        (word_map, transformer)
    """
    with open(wordmap_path, 'r', encoding='utf-8') as j:
        word_map = json.load(j)
    # SECURITY: torch.load unpickles arbitrary objects - only load trusted checkpoints.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects a
    # pickled nn.Module like checkpoint['transformer']; confirm the torch version.
    checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
    # Inputs are built on `device`, so the model must live there too.
    transformer = checkpoint['transformer'].to(device)
    return word_map, transformer


# Both model options currently use the same checkpoint. Streamlit re-runs this
# script on every interaction, so load once per session instead of every rerun.
if 'unilm_model' not in st.session_state:
    st.session_state['unilm_model'] = _load_model()
word_map, transformer = st.session_state['unilm_model']


def generate_response(prompt):
    """Answer ``prompt`` with the local transformer and log both turns.

    Appends the user prompt and the reply to ``st.session_state['messages']``.

    Returns:
        (response, total_tokens, prompt_tokens, completion_tokens). The counts
        are word-level: encoded prompt words and words in the reply.
    """
    st.session_state['messages'].append({"role": "user", "content": prompt})
    unk = word_map['<unk>']
    enc_qus = [word_map.get(word, unk) for word in remove_punc(prompt).split()]
    if not enc_qus:
        # A prompt of only punctuation/whitespace would give an empty encoder
        # input; feed a single <unk> so the model still has a token to attend to.
        enc_qus = [unk]
    question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)
    # Padding mask; assumes <pad> has id 0 in the word map - verify.
    question_mask = (question != 0).to(device).unsqueeze(1).unsqueeze(1)
    response = evaluate(transformer, question, question_mask, MAX_LEN, word_map)
    st.session_state['messages'].append({"role": "assistant", "content": response})

    prompt_tokens = len(enc_qus)
    completion_tokens = len(response.split())
    total_tokens = prompt_tokens + completion_tokens
    return response, total_tokens, prompt_tokens, completion_tokens


# Container for the chat history, rendered above the input box.
response_container = st.container()
# Container for the text box.
container = st.container()

with container:
    with st.form(key='my_form', clear_on_submit=True):
        # Recent Streamlit releases reject text_area heights below 68 px.
        user_input = st.text_area("You:", key='input', height=100)
        submit_button = st.form_submit_button(label='✉')

    if submit_button and user_input:
        output, total_tokens, prompt_tokens, completion_tokens = generate_response(user_input)
        st.session_state['past'].append(user_input)
        st.session_state['generated'].append(output)
        st.session_state['model_name'].append(model_name)
        st.session_state['total_tokens'].append(total_tokens)
        # The local model has no per-token price, so no cost is recorded.

if st.session_state['generated']:
    with response_container:
        history = zip(st.session_state['past'], st.session_state['generated'])
        for i, (user_text, bot_text) in enumerate(history):
            message(user_text, is_user=True, key=str(i) + '_user')
            message(bot_text, key=str(i))