File size: 2,068 Bytes
94a6c27
 
 
 
 
 
eb7e846
94a6c27
eb7e846
94a6c27
 
b97592a
 
 
94a6c27
eb7e846
94a6c27
 
b97592a
94a6c27
eb7e846
94a6c27
 
 
 
b97592a
94a6c27
 
 
 
eb7e846
94a6c27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da4c6e9
94a6c27
 
 
da4c6e9
94a6c27
 
 
 
 
 
da4c6e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# -*-coding:utf-8-*-
import streamlit as st
# code from https://huggingface.co/kakaobrain/kogpt
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM 


def _cache_resource(func):
    """Cache *func*'s return value across Streamlit script reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the model would be reloaded on every button click.
    ``st.cache_resource`` is used when available; older Streamlit releases
    only offer ``st.cache``, where ``allow_output_mutation=True`` avoids
    hashing the returned (mutable) model objects.
    """
    if hasattr(st, 'cache_resource'):
        return st.cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def _load_model():
    """Load the KoGPT tokenizer and model once.

    Returns:
        A ``(tokenizer, model, device)`` tuple. The model is moved to
        ``device`` and put in eval mode.
    """
    tok = AutoTokenizer.from_pretrained(
        'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b', cache_dir='./model_dir/',
        bos_token='[BOS]', eos_token='[EOS]', unk_token='[UNK]', pad_token='[PAD]', mask_token='[MASK]'
    )

    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # float16 is only used on CUDA: many CPU kernels have no half-precision
    # implementation, so generation on the CPU fallback would fail.
    dtype = torch.float16 if dev.type == 'cuda' else torch.float32

    mdl = AutoModelForCausalLM.from_pretrained(
        'kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b', cache_dir='./model_dir/',
        pad_token_id=tok.eos_token_id,
        torch_dtype=dtype, low_cpu_mem_usage=False
    ).to(device=dev, non_blocking=True)
    mdl.eval()

    print("Model loading done!")
    return tok, mdl, dev


# Module-level names used by gpt(); with caching these are the same objects
# on every rerun instead of a fresh multi-gigabyte load.
tokenizer, model, device = _load_model()

def gpt(prompt, *, temperature=0.8, max_length=256):
    """Continue *prompt* with KoGPT and return the decoded text.

    The result is ``tokenizer.batch_decode`` of the single generated
    sequence, i.e. the prompt followed by the sampled continuation.

    Args:
        prompt: Text to continue. An empty string is returned unchanged
            instead of being sent to the model.
        temperature: Sampling temperature passed to ``model.generate``.
        max_length: Total length limit in tokens (prompt + continuation)
            passed to ``model.generate``. If the prompt alone reaches this
            length, little or no new text can be produced.

    Returns:
        The decoded generated sequence as a string.
    """
    if not prompt:
        return prompt

    encoded = tokenizer(prompt, return_tensors='pt')
    # Forward the attention mask explicitly: pad_token_id is set to
    # eos_token_id for this model, so generate() cannot reliably infer it.
    inputs = {
        name: tensor.to(device=device, non_blocking=True)
        for name, tensor in encoded.items()
        if name in ('input_ids', 'attention_mask')
    }

    with torch.no_grad():
        gen_tokens = model.generate(
            **inputs,
            do_sample=True,
            temperature=temperature,
            max_length=max_length,
        )

    return tokenizer.batch_decode(gen_tokens)[0]
  
    
def _show_completion(prompt):
    """Generate a completion for *prompt* behind a spinner and render it.

    An empty prompt renders a blank result without calling the model.
    """
    with st.spinner('In progress.......'):
        output_text = gpt(prompt) if prompt else " "
    st.markdown("\n" + output_text)


# prompts
st.title("여러분들의 문장을 완성해줍니다. 🤖")
st.markdown("카카오 gpt 사용합니다.")
st.subheader("몇가지 예제: ")
example_1_str = "오늘의 날씨는 너무 눈부시다. 내일은 "
example_2_str = "우리는 행복을 언제나 갈망하지만 항상 "
example_1 = st.button(example_1_str)
example_2 = st.button(example_2_str)
textbox = st.text_area('오늘은 아름다움을 향해 달리고 ', '', height=100, max_chars=500)
button = st.button('생성:')

# output
st.subheader("결과값: ")
if example_1:
    _show_completion(example_1_str)
if example_2:
    _show_completion(example_2_str)
if button:
    _show_completion(textbox)