# tricksy / app.py
# Author: austinsilveria — commit 75fa479 ("Add tricksy"); file size 2.15 kB.
from threading import Thread
import streamlit as st
import torch
from transformers import AutoTokenizer, TextIteratorStreamer, set_seed
from modeling_tricksy import TricksyOPTForCausalLM, OPTDiskWeights
from configuration_tricksy import TricksyConfig
def _decoding_tokens_per_second(forward_times):
    """Return mean decoding throughput from per-forward-pass timings, or None.

    Equivalent to the original ``1 / mean(forward_times[1:])``, i.e.
    ``len(forward_times[1:]) / sum(forward_times[1:])``. The first entry is
    skipped as in the original stat; presumably it is the prompt (prefill)
    pass — NOTE(review): confirm against modeling_tricksy. Entries are assumed
    to be durations in seconds of one forward pass each — TODO confirm.

    Returns None when there is nothing to average (fewer than two timings) or
    the total time is not positive, instead of raising ZeroDivisionError.
    """
    decode_times = forward_times[1:]
    total = sum(decode_times)
    if len(decode_times) == 0 or total <= 0:
        return None
    return len(decode_times) / total


def generate():
    """Build the Tricksy OPT model, stream a completion for ``prompt``, then show stats.

    Reads the module-level widget values ``prompt``, ``use_tricksy``,
    ``max_new_tokens``, ``top_k`` and ``top_p`` that the script assigns below
    this function, so it must only be called once those names exist. A new
    model and tokenizer are constructed on every call. The generated text is
    streamed into a chat message. Decoding speed and current/peak CUDA memory
    are written to the page afterwards.
    """
    set_seed(42)
    # 13.4 GB (16 bit)
    model_name = 'facebook/opt-6.7b'
    disk_weights = OPTDiskWeights(model_name)
    # With Tricksy disabled, full_offload=True; per the 'Use Tricksy' toggle's
    # help text, all weights are then sent to the GPU instead of only the
    # sparse MLP weight diffs.
    tricksy_model = TricksyOPTForCausalLM(
        TricksyConfig(disk_weights.config, full_offload=(not use_tricksy)),
        disk_weights,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

    inputs = tokenizer(prompt, return_tensors='pt').input_ids.to('cuda')

    generation_kwargs = dict(
        inputs=inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_k=top_k,
        top_p=top_p,
    )
    # generate() blocks until it finishes, so run it on a background thread and
    # consume decoded text from the streamer here as it arrives.
    # NOTE(review): if generate() raises inside the thread, the streamer never
    # ends and the loop below waits forever; consider a streamer timeout.
    thread = Thread(target=tricksy_model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ''
    with st.chat_message("user"):
        t = st.empty()
        for new_text in streamer:
            # Markdown needs two trailing spaces before '\n' for a hard line break.
            generated_text += new_text.replace('\n', '  \n')
            t.write(generated_text)
    # Make sure generation has fully finished before reading its timing stats.
    thread.join()

    tok_per_s = _decoding_tokens_per_second(tricksy_model.tricksy_context.forward_times)
    tok_per_s_text = 'n/a' if tok_per_s is None else f'{tok_per_s:.2f}'
    gib = 1024 ** 3
    stats_text = f'Decoding tok/s: {tok_per_s_text}'
    stats_text += f'  \nCurrent GPU mem usage: {torch.cuda.memory_allocated("cuda") / gib:.2f} GB'
    stats_text += f'  \nMax GPU mem usage: {torch.cuda.max_memory_allocated("cuda") / gib:.2f} GB'
    st.write(stats_text)
# Page layout: prompt, Submit button, Tricksy toggle and sampling options.
prompt = st.text_area('Prompt', 'Making pesto from scratch can be done with these ingredients in 4 simple steps:\nStep 1')

col1, col2 = st.columns(2)
with col1:
    # Plain button instead of on_click=generate. Streamlit runs on_click
    # callbacks before the script body re-executes, so a callback would see the
    # widget globals (prompt, use_tricksy, ...) assigned during the previous run,
    # and its output would render above these widgets.
    submit = st.button('Submit')
with col2:
    use_tricksy = st.toggle('Use Tricksy', True, help='If true, only send sparse MLP weight diffs to GPU. If false, send all weights to GPU.')

with st.expander('Additional options'):
    max_new_tokens = st.slider('Max new tokens', 1, 500, 100)
    top_k = st.slider('Top-k sampling', 1, 500, 50)
    top_p = st.slider('Top-p (nucleus sampling)', 0.0, 1.0, .9)

# Run generation only after every widget above has been evaluated in this run,
# so generate() reads the current prompt and option values.
if submit:
    generate()