File size: 1,942 Bytes
0d215ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import streamlit as st
import requests
import time

def infer(prompt, max_new_tokens=10, temperature=0.0, top_p=1.0, *,
          poll_interval=1.0, max_wait=600.0, request_timeout=30.0):
    """Run one text completion through the planetd job service.

    Submits a ``"language-model-inference"`` job for ``prompt``, then polls
    the job until its status becomes ``'finished'`` and returns the text of
    the first generated choice.

    Args:
        prompt: Text prompt; sent to the service as a one-element list.
        max_new_tokens: Value sent as the payload's ``max_tokens`` field.
        temperature: Sampling temperature (coerced to float).
        top_p: Nucleus-sampling mass (coerced to float).
        poll_interval: Seconds to sleep between status polls.
        max_wait: Maximum total seconds to wait for the job to finish.
            ``None`` waits indefinitely (the previous behaviour).
        request_timeout: Timeout in seconds for each individual HTTP request,
            so a stalled connection cannot block forever.

    Returns:
        The ``text`` field of the first choice of the first inference result.

    Raises:
        requests.HTTPError: If the service answers with an HTTP error status.
        requests.RequestException: On connection errors or request timeouts.
        TimeoutError: If the job is not ``'finished'`` within ``max_wait``.
        ValueError: If a response body is not valid JSON.
        KeyError / IndexError: If a response lacks the expected fields.
    """
    base_url = "https://planetd.shift.ml"

    job_request = {
        "type": "general",
        "payload": {
            "max_tokens": int(max_new_tokens),
            "n": 1,
            "temperature": float(temperature),
            "top_p": float(top_p),
            "model": "Together-gpt-J-6B-ProxAdam-50x",
            "prompt": [prompt],
            "request_type": "language-model-inference",
            "stop": None,
            "best_of": 1,
            "echo": False,
            "seed": 42,
            "prompt_embedding": False,
        },
        "returned_payload": {},
        "status": "submitted",
        "source": "dalle",
    }

    submit = requests.post(f"{base_url}/jobs", json=job_request, timeout=request_timeout)
    # Fail loudly on HTTP errors instead of a KeyError on a missing 'id'.
    submit.raise_for_status()
    job_id = submit.json()['id']

    # Monotonic clock so wall-clock adjustments cannot distort the deadline.
    deadline = None if max_wait is None else time.monotonic() + max_wait

    while True:
        # The id is sent both in the path and as a JSON body, as the original
        # client did. NOTE(review): confirm whether the server needs the body.
        poll = requests.get(f"{base_url}/job/{job_id}", json={'id': job_id},
                            timeout=request_timeout)
        poll.raise_for_status()
        ret = poll.json()

        if ret['status'] == 'finished':
            break

        # NOTE(review): a job ending in a failure state (its status name is not
        # visible here) is only bounded by this deadline; consider raising on
        # it directly once the server's status values are confirmed.
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"job {job_id} not finished after {max_wait} s "
                f"(last status: {ret['status']!r})"
            )

        time.sleep(poll_interval)

    return ret['returned_payload']['result']['inference_result'][0]['choices'][0]['text']
    
    
st.title("TOMA Application")

# Default prompt pre-filled in the text area; shows the question/answer format.
s_example = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
prompt = st.text_area(
    "Prompt",
    value=s_example,
    max_chars=1000,
    height=400,
)

# Placeholder whose content is replaced by the model output after a run.
generated_area = st.empty()
generated_area.markdown("(Generate here)")

button_submit = st.button("Submit")

# Sampling controls forwarded to infer().
max_new_tokens = st.number_input('Max new tokens', 1, 1024, 10)
temperature = st.number_input('temperature', 0.0, 10.0, 0.0, step=0.1, format="%.2f")
top_p = st.number_input('top_p', 0.0, 1.0, 1.0, step=0.1, format="%.2f")

if button_submit:
    if not (prompt or "").strip():
        # Nothing to complete; don't send an empty job to the service.
        st.warning("Please enter a prompt before submitting.")
    else:
        with st.spinner(text="In progress.."):
            try:
                report_text = infer(prompt, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p)
            except (requests.RequestException, TimeoutError, KeyError, IndexError, ValueError) as err:
                # Report network/HTTP failures, timeouts and unexpected
                # response shapes in the page instead of a raw traceback.
                st.error(f"Inference failed: {err}")
            else:
                generated_area.markdown(report_text)