|
|
|
import streamlit as st |
|
from streamlit_chat import message |
|
import json |
|
import torch |
|
from torch.utils.data import Dataset |
|
import torch.utils.data |
|
from models import * |
|
from utils import * |
|
|
|
# Configure the browser tab, then draw the centred page heading.
st.set_page_config(page_icon=":robot_face:", page_title="UniLM")
_PAGE_HEADING_HTML = "<h1 style='text-align: center;'>UniLM</h1>"
st.markdown(_PAGE_HEADING_HTML, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
# Seed per-session chat state. Streamlit re-executes this script on every
# interaction, so a key is only initialised when it is missing. Each factory
# builds a fresh value, and these values match what the "Clear Conversation"
# handler resets them to. In particular, `total_cost` starts at 0.0: the
# previous seed of 1 made a brand-new conversation report a non-zero cost.
_SESSION_DEFAULTS = {
    'generated': list,       # model replies, one per turn
    'past': list,            # user prompts, one per turn
    'messages': lambda: [
        {"role": "system", "content": "You are a helpful assistant."}
    ],
    'model_name': list,      # sidebar model used for each turn
    'cost': list,            # per-turn cost labels
    'total_tokens': list,    # per-turn token counts
    'total_cost': lambda: 0.0,
}

for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
|
|
|
|
|
# Sidebar controls: heading, model picker, an empty placeholder slot, and the
# button that wipes the current conversation.
_sidebar = st.sidebar
_sidebar.title("Settings")
_MODEL_CHOICES = ("30M_6.1K", "NONE")
model_name = _sidebar.selectbox("Model:", _MODEL_CHOICES)
counter_placeholder = _sidebar.empty()
clear_button = _sidebar.button("Clear Conversation", key="clear")
|
|
|
|
|
# Map the sidebar choice to a model identifier. Anything other than the local
# checkpoint falls back to "gpt-4".
# NOTE(review): `model` is not read anywhere else in the code shown here.
model = "30M_6.1K" if model_name == "30M_6.1K" else "gpt-4"
|
|
|
|
|
# "Clear Conversation": overwrite every per-session value with an empty chat.
# Keys are assigned in the same order as before.
if clear_button:
    _fresh_state = {
        'generated': [],
        'past': [],
        'messages': [
            {"role": "system", "content": "You are a helpful assistant."}
        ],
        'number_tokens': [],
        'model_name': [],
        'cost': [],
        'total_cost': 0.0,
        'total_tokens': [],
    }
    for _state_key, _fresh_value in _fresh_state.items():
        st.session_state[_state_key] = _fresh_value
|
|
|
|
|
|
|
def evaluate(transformer, question, question_mask, max_len, word_map, device=None):
    """Greedily decode a reply to ``question`` with a batch size of 1.

    Args:
        transformer: Model exposing ``eval()``, ``encode(question, mask)``,
            ``decode(words, target_mask, encoded, question_mask)`` and
            ``logit(hidden)``; it is called exactly that way below.
        question: Token-id tensor for the question. ``generate_response``
            builds it with shape ``(1, seq_len)``.
        question_mask: Mask for ``question``, passed through unchanged to
            ``encode`` and ``decode``.
        max_len: Upper bound on the decoded sequence length *including* the
            ``<start>`` token, so at most ``max_len - 1`` tokens are generated.
        word_map: Token -> id vocabulary. It must contain ``<start>`` and
            ``<end>``, and every generated id must map back to a token.
        device: Device for the decoder inputs. Defaults to ``question.device``
            (previously an implicit module-level ``device`` global was used).

    Returns:
        The generated tokens joined by single spaces. ``<start>`` ids are
        dropped and decoding stops (exclusive) at the first ``<end>``.
    """
    if device is None:
        device = question.device

    rev_word_map = {idx: token for token, idx in word_map.items()}
    start_token = word_map['<start>']
    end_token = word_map['<end>']

    transformer.eval()
    # Pure inference: don't record an autograd graph for every decoding step.
    with torch.no_grad():
        encoded = transformer.encode(question, question_mask)
        words = torch.tensor([[start_token]], dtype=torch.long, device=device)

        for _ in range(max_len - 1):
            size = words.shape[1]
            # Lower-triangular (causal) mask: position i may only attend to
            # positions <= i. Shape (1, 1, size, size), dtype uint8.
            target_mask = torch.tril(torch.ones(size, size, device=device)).to(torch.uint8)
            target_mask = target_mask.unsqueeze(0).unsqueeze(0)

            decoded = transformer.decode(words, target_mask, encoded, question_mask)
            predictions = transformer.logit(decoded[:, -1])
            _, next_word = torch.max(predictions, dim=1)
            next_word = next_word.item()

            if next_word == end_token:
                break

            next_tensor = torch.tensor([[next_word]], dtype=torch.long, device=device)
            words = torch.cat([words, next_tensor], dim=1)

    token_ids = words.squeeze(0).tolist()
    return ' '.join(rev_word_map[idx] for idx in token_ids if idx != start_token)
|
# Characters stripped by remove_punc(). The doubled backslash keeps a literal
# backslash in the set. The previous spelling used the invalid escape "\,",
# which warns (SyntaxWarning on Python 3.12+) but has the same value.
# Note that "+", "=", "|" and "`" are NOT in this set.
# NOTE(review): presumably this must match the preprocessing used to build
# WORDMAP_corpus.json; confirm before changing the set.
_PUNCTUATIONS = '''!()-[]{};:'"\\,<>./?@#$%^&*_~'''
_PUNCT_DELETE_TABLE = str.maketrans('', '', _PUNCTUATIONS)


def remove_punc(string):
    """Return ``string`` with the characters in ``_PUNCTUATIONS`` removed, lower-cased.

    Uses a precomputed ``str.translate`` table: one C-level pass instead of
    the previous quadratic character-by-character string concatenation.
    """
    return string.translate(_PUNCT_DELETE_TABLE).lower()
|
|
|
# Streamlit re-executes this whole script on every widget interaction. Without
# caching, the checkpoint and vocabulary would be re-read from disk on each
# rerun. `st.cache_resource` keeps one loaded copy per process; on Streamlit
# versions without it, fall back to loading uncached.
_cache_resource = getattr(st, "cache_resource", None) or (lambda func: func)


@_cache_resource
def _load_chatbot(ckpt_path, wordmap_path):
    """Load the token->id vocabulary and the pickled transformer.

    Returns:
        ``(word_map, transformer)``, with the transformer moved to ``device``
        (the same device ``generate_response`` sends its input tensors to).
    """
    with open(wordmap_path, 'r', encoding='utf-8') as j:
        word_map = json.load(j)
    # NOTE(review): the checkpoint stores a whole pickled model object, not just
    # a state_dict, so torch.load unpickles arbitrary objects. Only load trusted
    # files. torch >= 2.6 defaults to weights_only=True and will reject this file
    # unless weights_only=False is passed explicitly.
    checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
    # Presumably an nn.Module (evaluate() calls .eval() on it). It was mapped to
    # CPU above; move it next to the inputs, which are sent to `device`.
    transformer = checkpoint['transformer'].to(device)
    return word_map, transformer


# Every sidebar choice currently uses the same checkpoint and vocabulary (the
# previous if/else branches were identical).
ckpt_path = 'checkpoint_190.pth.tar'
word_map, transformer = _load_chatbot(ckpt_path, 'WORDMAP_corpus.json')
|
|
|
|
|
|
|
|
|
def generate_response(prompt):
    """Answer ``prompt`` with the local transformer and log both sides of the turn.

    The user prompt and the model reply are appended, in that order, to
    ``st.session_state['messages']``.

    Returns:
        A 4-tuple ``(response, total_tokens, prompt_tokens, completion_tokens)``.
        The three token fields are the fixed placeholder string "153", not
        real counts.
    """
    st.session_state['messages'].append({"role": "user", "content": prompt})

    decode_limit = 153
    cleaned_words = remove_punc(prompt).split()
    # Unknown words fall back to the <unk> id.
    token_ids = [word_map.get(token, word_map['<unk>']) for token in cleaned_words]

    question = torch.LongTensor(token_ids).to(device).unsqueeze(0)
    # Keep every non-zero id (0 is presumably the padding id — confirm against
    # the word map), reshaped to (1, 1, 1, seq_len).
    question_mask = (question != 0).to(device).unsqueeze(1).unsqueeze(1)

    response = evaluate(transformer, question, question_mask, decode_limit, word_map)

    st.session_state['messages'].append({"role": "assistant", "content": response})

    placeholder_count = "153"
    return response, placeholder_count, placeholder_count, placeholder_count
|
|
|
|
|
|
|
# Conversation history is rendered into `response_container`. It is created
# before `container`, so the history sits above the input form on the page.
response_container = st.container()
container = st.container()

with container:
    with st.form(key='my_form', clear_on_submit=True):
        # `height` is in pixels. Current Streamlit rejects text_area heights
        # below 68 px, so the previous height=2 raised StreamlitAPIException.
        user_input = st.text_area("You:", key='input', height=68)
        submit_button = st.form_submit_button(label='β')

    if submit_button and user_input:
        output, total_tokens, prompt_tokens, completion_tokens = generate_response(user_input)

        # Per-turn bookkeeping. These lists are appended together so index i
        # always refers to the same turn.
        st.session_state['past'].append(user_input)
        st.session_state['generated'].append(output)
        st.session_state['model_name'].append(model_name)
        st.session_state['total_tokens'].append(total_tokens)

        # Placeholder per-turn cost label. Previously it was computed and then
        # discarded, leaving the session's 'cost' list permanently empty.
        cost = "1" if model_name == "30M_6.1K" else "2"
        st.session_state['cost'].append(cost)
|
|
|
|
|
|
|
# Replay the whole conversation (oldest turn first) into the history container.
# Each turn shows the user's prompt followed by the model's reply.
if st.session_state['generated']:
    with response_container:
        for turn, reply in enumerate(st.session_state['generated']):
            message(st.session_state["past"][turn], is_user=True, key=f"{turn}_user")
            message(reply, key=f"{turn}")
|
|
|
|