from threading import Thread
import gc
import time

import streamlit as st

import torch
from transformers import AutoTokenizer, TextIteratorStreamer, set_seed
from modeling_tricksy import TricksyOPTForCausalLM, OPTDiskWeights
from configuration_tricksy import TricksyConfig

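# Track whether a generation is in progress so the Submit button can be disabled while it runs.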
if 'submit' in st.session_state and st.session_state.submit:
    st.session_state.generating = True
else:
    st.session_state.generating = False

prompt = st.text_area('Prompt', 'Making pesto from scratch can be done with these ingredients in 4 simple steps:\nStep 1')

col1, col2 = st.columns(2)

with st.expander('Additional options'):
    max_new_tokens = st.slider('Max new tokens', 1, 500, 50)
    top_k = st.slider('Top-k sampling', 1, 500, 50)
    top_p = st.slider('Top-p (nucleus sampling)', 0.0, 1.0, .9)

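# Placeholders: the chat message receives the streamed response, `stats` is filled in after generation.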
out = st.chat_message('user')
stats = st.empty()

with col1:
    use_tricksy = st.toggle('Use Tricksy', True, help='If true, only send sparse MLP weight diffs to GPU. If false, send all weights to GPU.')
with col2:
    if st.button('Submit', disabled=st.session_state.generating, key='submit'):
        set_seed(42)
        # 13.4 GB (16 bit)
        model_name = 'facebook/opt-6.7b'
        disk_weights = OPTDiskWeights(model_name)
        tricksy_model = TricksyOPTForCausalLM(TricksyConfig(disk_weights.config, full_offload=(not use_tricksy)), disk_weights)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

        inputs = tokenizer(prompt, return_tensors='pt').input_ids.to('cuda')

        generation_kwargs = dict(inputs=inputs, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, top_p=top_p)
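        # Run generation on a background thread so the TextIteratorStreamer can be consumed here as tokens arrive.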
        thread = Thread(target=tricksy_model.generate, kwargs=generation_kwargs)
        thread.start()
        generated_text = ''
        with out:
            t = st.empty()
            for new_text in streamer:
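                # Two trailing spaces turn newlines into Markdown hard line breaks for st.write.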
                generated_text += new_text.replace('\n', '  \n')
                t.write(generated_text)

        # Skip the first forward pass (prompt prefill) when averaging per-token decode time.
        avg_decode_time = sum(tricksy_model.tricksy_context.forward_times[1:]) / (len(tricksy_model.tricksy_context.forward_times) - 1)
        stats_text = f'Decoding tok/s: {1 / avg_decode_time}'
        stats_text += f'  \nCurrent GPU mem usage: {torch.cuda.memory_allocated("cuda") / 1024 ** 3} GB'
        stats_text += f'  \nMax GPU mem usage: {torch.cuda.max_memory_allocated("cuda") / 1024 ** 3} GB'
        stats.write(stats_text)

        # Drop model references and free GPU memory so the next run starts from a clean slate.
        disk_weights = None
        tricksy_model = None
        time.sleep(.2)
        # st.write(f'num open files: {len(psutil.Process().open_files())}')
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()