File size: 5,451 Bytes
5e48162
0d215ca
 
 
e532db6
0d215ca
c694998
 
f203ba6
5fcdcad
7e6795e
 
 
 
 
 
 
 
 
 
 
3442116
 
 
7e6795e
 
 
 
 
 
0d215ca
7e6795e
 
 
 
0b3be54
7e6795e
f79758c
9cecd9c
f203ba6
5e48162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f203ba6
5e48162
 
 
 
 
 
 
 
 
 
 
7e6795e
 
5fcdcad
7e6795e
5e48162
5fcdcad
7e6795e
 
 
5fcdcad
7e6795e
 
0d215ca
5fcdcad
 
7e6795e
 
 
 
 
 
 
 
 
 
 
5fcdcad
0d215ca
7e6795e
0d215ca
5fcdcad
 
7e6795e
 
5fcdcad
5e48162
5fcdcad
 
0d215ca
7e6795e
 
0d215ca
a531b86
7e6795e
 
 
 
 
 
5fcdcad
7e6795e
 
5e48162
a531b86
7e6795e
 
 
 
 
 
5fcdcad
 
 
 
 
7e6795e
 
 
 
 
 
 
 
5fcdcad
7e6795e
 
 
 
 
 
 
 
 
 
 
5fcdcad
 
 
 
 
 
 
7e6795e
5fcdcad
 
7e6795e
 
 
5fcdcad
7e6795e
 
 
c694998
7e6795e
 
 
 
c694998
7e6795e
df71374
7e6795e
 
c694998
7e6795e
 
5e48162
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186

import html
import time
from ast import literal_eval

import requests
import streamlit as st

def to_md(text):
    """Render plain text as an HTML fragment for ``st.markdown(..., unsafe_allow_html=True)``.

    HTML special characters are escaped first, so text such as ``<label_0>``
    in a prompt or in model output is displayed literally instead of being
    parsed as a tag.  Newlines are then turned into ``<br />`` line breaks.
    """
    # Escape before inserting <br />, otherwise the inserted tags would be escaped too.
    return html.escape(text, quote=False).replace("\n", "<br />")

@st.cache
def infer(
    prompt,
    model_name,
    max_new_tokens=10,
    temperature=0.0,
    top_p=1.0,
    top_k=40,
    num_completions=1,
    seed=42,
    stop="\n",
):
    """Run one completion request against the planetd job API and return its text.

    The request is submitted as a job and then polled until the job reports
    ``status == 'finished'``.  Numeric arguments may be passed as strings
    (the UI passes raw text-input values) and are converted here.

    Args:
        prompt: Prompt text sent to the model.
        model_name: Display name of the model; must be a key of ``model_name_map``.
        max_new_tokens: Maximum number of tokens to generate, 0..256.
        temperature: Sampling temperature, 0.0..10.0; 0.0 is sent as 0.01.
        top_p: Nucleus-sampling threshold, 0.0..1.0.
        top_k: Accepted but currently NOT sent to the API (see note below).
        num_completions: Number of completions requested, 1..5; only the
            first returned choice is used.
        seed: Sampling seed.
        stop: Stop sequences separated by ';'.  Empty entries are ignored.

    Returns:
        The text of the first returned choice, cut before the earliest
        occurrence of any stop sequence.

    Raises:
        ValueError: If an argument is non-numeric or out of range.
        KeyError: If ``model_name`` is unknown.
        requests.HTTPError: If submitting or polling the job returns an HTTP error.
        TimeoutError: If the job is not finished after the polling budget.
    """
    model_name_map = {
        "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
    }
    api_root = "https://planetd.shift.ml"
    request_timeout = 30  # seconds allowed per HTTP request
    poll_interval = 0.5   # seconds between job-status checks
    max_polls = 100       # total polling budget: max_polls * poll_interval seconds

    max_new_tokens = int(max_new_tokens)
    num_completions = int(num_completions)
    temperature = float(temperature)
    top_p = float(top_p)
    seed = int(seed)
    # An empty entry (from "" or a trailing ';') would match at index 0 and
    # truncate the whole completion, so drop it.
    stop_words = [word for word in stop.split(";") if word]

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not 0 <= max_new_tokens <= 256:
        raise ValueError(f"max_new_tokens must be in [0, 256], got {max_new_tokens}")
    if not 1 <= num_completions <= 5:
        raise ValueError(f"num_completions must be in [1, 5], got {num_completions}")
    if not 0.0 <= temperature <= 10.0:
        raise ValueError(f"temperature must be in [0.0, 10.0], got {temperature}")
    if not 0.0 <= top_p <= 1.0:
        raise ValueError(f"top_p must be in [0.0, 1.0], got {top_p}")

    # NOTE(review): top_k is accepted (and exposed in the UI) but is not part
    # of the payload below, so it has no effect. Confirm the API accepts a
    # "top_k" field before forwarding it.

    # A temperature of exactly 0.0 is replaced by a small positive value
    # before it is sent.
    if temperature == 0.0:
        temperature = 0.01

    my_post_dict = {
        "type": "general",
        "payload": {
            "max_tokens": max_new_tokens,
            "n": num_completions,
            "temperature": temperature,
            "top_p": top_p,
            "model": model_name_map[model_name],
            "prompt": [prompt],
            "request_type": "language-model-inference",
            "stop": stop_words,
            "best_of": 1,
            "echo": False,
            "seed": seed,
            "prompt_embedding": False,
        },
        "returned_payload": {},
        "status": "submitted",
        "source": "dalle",
    }

    response = requests.post(f"{api_root}/jobs", json=my_post_dict, timeout=request_timeout)
    response.raise_for_status()
    job_id = response.json()['id']

    for _ in range(max_polls):
        time.sleep(poll_interval)
        response = requests.get(
            f"{api_root}/job/{job_id}", json={'id': job_id}, timeout=request_timeout
        )
        response.raise_for_status()
        ret = response.json()
        if ret['status'] == 'finished':
            break
    else:
        # The loop ran out without `break`: the job never reported completion.
        raise TimeoutError(
            f"Job {job_id} did not finish after {max_polls * poll_interval:g}s "
            f"(last status: {ret['status']!r})"
        )

    generated_text = ret['returned_payload']['result']['inference_result'][0]['choices'][0]['text']

    # Cut at the earliest occurrence of any stop word. Sequential truncation
    # could miss an earlier stop word that overlaps a later cut point.
    hits = [pos for pos in (generated_text.find(word) for word in stop_words) if pos != -1]
    if hits:
        generated_text = generated_text[:min(hits)]

    # NOTE(review): because of @st.cache, this line runs only on a cache miss;
    # a cached call returns without setting the flag.
    st.session_state.updated = True

    return generated_text


def set_preset():
    """Load the template and sampling settings for ``st.session_state.preset``.

    Intended as the ``on_change`` callback of the (currently commented-out)
    "Recommended Templates" radio in ``main``.  For a recognised preset it
    writes string values into ``temperature``, ``top_p`` and
    ``max_new_tokens``.  It replaces ``prompt`` with the preset template
    only while ``st.session_state.updated`` is falsy; ``infer`` sets that
    flag to True.  Any other preset value leaves the session state untouched.
    """
    classification_prompt = '''Please classify the given sentence.
Possible labels:
1. <label_0>
2. <label_1>

Input: <sentence_0>
Label: <label_0>

Input: <sentence_1>
Label:'''
    generation_prompt = '''Please write a story given keywords.

Input: bear, honey
Story: Once upon a time,'''

    # preset name -> (prompt template, temperature, top_p, max_new_tokens)
    settings_by_preset = {
        "Classification": (classification_prompt, "0.0", "1.0", "10"),
        "Generation": (generation_prompt, "0.0", "0.9", "100"),
    }

    selected = settings_by_preset.get(st.session_state.preset)
    if selected is None:
        return

    template, temperature, top_p, max_new_tokens = selected
    # Keep the user's own prompt once a generation has been run.
    if not st.session_state.updated:
        st.session_state.prompt = template
    st.session_state.temperature = temperature
    st.session_state.top_p = top_p
    st.session_state.max_new_tokens = max_new_tokens
    
    
def _init_session_state():
    """Seed st.session_state with the UI defaults for every key not yet set."""
    defaults = {
        'preset': "Classification",
        'prompt': "Please answer the following question:\n\nQuestion: In which country is Zurich located?\nAnswer:",
        'temperature': "0.0",
        'top_p': "1.0",
        'top_k': "40",
        'max_new_tokens': "10",
        # set_preset() only overwrites the prompt while this is falsy.
        'updated': False,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _decode_stop(raw):
    """Interpret Python string escapes (e.g. ``\\n``) in the stop-sequence input.

    Returns the decoded string.  Raises ValueError when ``raw`` is not a valid
    string-literal body, e.g. when it ends with a backslash or a quote, has an
    invalid escape, or would evaluate to something other than a string.
    """
    try:
        decoded = literal_eval("'''" + raw + "'''")
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
        # The exceptions literal_eval documents for malformed input.
        raise ValueError(f"Invalid stop sequence {raw!r}: {err}") from err
    # Input such as "''',1,'''" turns the expression into a tuple, not a str.
    if not isinstance(decoded, str):
        raise ValueError(f"Invalid stop sequence {raw!r}")
    return decoded


def main():
    """Render the GPT-JT playground page and run a generation when Submit is pressed."""
    _init_session_state()

    st.title("GPT-JT")

    col1, col2 = st.columns([1, 3])

    with col1:
        model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
        max_new_tokens = st.text_input('Max new tokens', st.session_state.max_new_tokens)
        temperature = st.text_input('temperature', st.session_state.temperature)
        top_k = st.text_input('top_k', st.session_state.top_k)
        top_p = st.text_input('top_p', st.session_state.top_p)
        # num_completions and seed are not exposed in the UI; fixed values are used.
        # num_completions = st.text_input('num_completions (only the best one will be returned)', "1")
        num_completions = "1"
        stop = st.text_input('stop, split by;', r'\n')
        # seed = st.text_input('seed', "42")
        seed = "42"

    with col2:

        # preset = st.radio(
        #     "Recommended Templates", 
        #     ('Classification', 'Generation'), 
        #     on_change=set_preset,
        #     key="preset",
        #     horizontal=True
        # )

        prompt_area = st.empty()
        prompt = prompt_area.text_area(
            "Prompt",
            value=st.session_state.prompt,
            max_chars=4096,
            height=300,
        )

        generated_area = st.empty()
        generated_area.markdown("(Generate here)")

        button_submit = st.button("Submit")

        if button_submit:
            prompt_html = "<b>" + to_md(prompt) + "</b>"
            generated_area.markdown(prompt_html, unsafe_allow_html=True)
            try:
                report_text = infer(
                    prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k,
                    num_completions=num_completions, seed=seed, stop=_decode_stop(stop),
                )
            except ValueError as err:
                # Malformed user input, e.g. an undecodable stop sequence or a
                # non-numeric field, is reported instead of crashing the page.
                st.error(str(err))
                return
            generated_area.markdown(prompt_html + to_md(report_text), unsafe_allow_html=True)
        
if __name__ == '__main__':
    # Build the page only when the module is executed as a script, not on import.
    main()