Spaces:
Runtime error
Runtime error
import streamlit as st | |
# Streamlit re-executes this whole script on every widget interaction, so the
# loader must be cached or the model and tokenizer are re-instantiated on every
# click. `st.cache_resource` exists from Streamlit 1.18 onward; older releases
# provide the equivalent `st.experimental_singleton`.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton


@_cache_resource
def get_pipe():
    """Load the KoAlpaca-355M causal language model and its tokenizer.

    The result is cached by Streamlit, so the weights are loaded once per
    server process instead of once per script rerun.

    Returns:
        A ``(model, tokenizer)`` tuple -- note that the model comes first.
    """
    # Deferred import: transformers is only needed when the model is loaded.
    from transformers import AutoTokenizer, AutoModelForCausalLM

    model_name = "heegyu/koalpaca-355m"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Over-long inputs are cut from the end. For prompts built as
    # "<usr>...\n<sys>", this can drop the trailing "<sys>" marker.
    tokenizer.truncation_side = "right"
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer
def get_response(tokenizer, model, context):
    """Generate the model's answer to a single user question.

    The question is wrapped in the ``<usr>{question}\\n<sys>`` prompt
    template, tokenized (truncated to at most 512 tokens), and passed to
    ``model.generate`` with sampling enabled.

    Args:
        tokenizer: Tokenizer as returned by ``get_pipe``. It is called like a
            Hugging Face tokenizer and must provide ``decode``.
        model: Model as returned by ``get_pipe``. It must provide ``generate``.
        context: The raw user question.

    Returns:
        Only the generated continuation (the prompt is not included), with
        any literal ``"</s>"`` markers removed.
    """
    prompt = f"<usr>{context}\n<sys>"
    inputs = tokenizer(
        prompt,
        truncation=True,
        max_length=512,
        return_tensors="pt")
    generation_args = dict(
        # max_length/min_length count the prompt tokens as well, so a prompt
        # longer than 256 tokens leaves no room for new tokens.
        max_length=256,
        min_length=64,
        # Hard-coded end-of-sequence id; presumably the model's "</s>" -- confirm.
        eos_token_id=2,
        do_sample=True,
        top_p=1.0,
        # Per the transformers docs this only affects beam-based search.
        early_stopping=True
    )
    outputs = model.generate(**inputs, **generation_args)
    print(prompt)

    # Strip the prompt by token count, not by character count. The prompt may
    # have been truncated to 512 tokens, and decode() need not reproduce the
    # original string exactly. Slicing the decoded text by len(prompt) could
    # therefore cut into the answer or leave part of the prompt behind.
    prompt_len = len(inputs["input_ids"][0])
    new_tokens = outputs[0][prompt_len:]
    response = tokenizer.decode(new_tokens)
    print(response)
    return response.replace("</s>", "")
st.title("KoAlpaca-355M")
with st.spinner("loading model..."):
    model, tokenizer = get_pipe()

# Label: "Try asking a question"; default question: "What is the cause of
# the conflict between the US and China?"
input_ = st.text_area("질문해보세요", value="미국과 중국의 갈등의 원인이 뭐야?")
# Button label: "Ask".
ok = st.button("물어보기")

# Only query the model on a click with a non-blank question; whitespace-only
# input would otherwise be sent to the model as an empty prompt.
if ok and input_ is not None and input_.strip():
    # Spinner text: "One moment please".
    with st.spinner("잠시만요"):
        response = get_response(tokenizer, model, input_)
        # Heading: "Answer".
        st.text("대답")
        st.success(response)