# tricksy/app.py
from threading import Thread
import gc
import time
import streamlit as st
import torch
from transformers import AutoTokenizer, TextIteratorStreamer, set_seed
from modeling_tricksy import TricksyOPTForCausalLM, OPTDiskWeights
from configuration_tricksy import TricksyConfig
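
# Disable the Submit button while a generation is already in flight so runs do not overlap.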
if 'submit' in st.session_state and st.session_state.submit:
    st.session_state.generating = True
else:
    st.session_state.generating = False
prompt = st.text_area('Prompt', 'Making pesto from scratch can be done with these ingredients in 4 simple steps:\nStep 1')
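
# Sampling options live in an expander; the two columns below hold the Tricksy toggle and the Submit button.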
col1, col2 = st.columns(2)
with st.expander('Additional options'):
    max_new_tokens = st.slider('Max new tokens', 1, 500, 50)
    top_k = st.slider('Top-k sampling', 1, 500, 50)
    top_p = st.slider('Top-p (nucleus sampling)', 0.0, 1.0, .9)
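
# Placeholders for the streamed chat output and the runtime stats.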
out = st.chat_message('user')
stats = st.empty()
with col1:
    use_tricksy = st.toggle('Use Tricksy', True, help='If true, only send sparse MLP weight diffs to GPU. If false, send all weights to GPU.')
with col2:
    if st.button('Submit', disabled=st.session_state.generating, key='submit'):
        set_seed(42)
        # facebook/opt-6.7b is ~13.4 GB in 16-bit precision
        model_name = 'facebook/opt-6.7b'
        disk_weights = OPTDiskWeights(model_name)
        tricksy_model = TricksyOPTForCausalLM(TricksyConfig(disk_weights.config, full_offload=(not use_tricksy)), disk_weights)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
        inputs = tokenizer(prompt, return_tensors='pt').input_ids.to('cuda')
        print()
        # Run generation on a background thread so the streamer can be consumed below.
        generation_kwargs = dict(inputs=inputs, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, top_p=top_p)
        thread = Thread(target=tricksy_model.generate, kwargs=generation_kwargs)
        thread.start()
        generated_text = ''
        with out:
            t = st.empty()
            # Stream tokens into the chat message as they arrive; two trailing spaces force a markdown line break.
            for new_text in streamer:
                generated_text += new_text.replace('\n', '  \n')
                t.write(generated_text)
        # Average decode speed, skipping the first (prefill) forward pass.
        stats_text = f'Decoding tok/s: {1 / (sum(tricksy_model.tricksy_context.forward_times[1:]) / (len(tricksy_model.tricksy_context.forward_times) - 1))}'
        stats_text += f'  \nCurrent GPU mem usage: {torch.cuda.memory_allocated("cuda") / 1024 ** 3} GB'
        stats_text += f'  \nMax GPU mem usage: {torch.cuda.max_memory_allocated("cuda") / 1024 ** 3} GB'
        stats.write(stats_text)
        # Drop references to the weights and model so the next run starts from a clean slate.
        disk_weights = None
        tricksy_model = None
        time.sleep(.2)
        # st.write(f'num open files: {len(psutil.Process().open_files())}')
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.reset_peak_memory_stats()