Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
import numpy as np | |
import streamlit as st | |
from transformers import AutoModelWithLMHead, PreTrainedTokenizerFast | |
# Hugging Face Hub id of the conditional KoGPT checkpoint used for generation.
model_dir = "snoop2head/kogpt-conditional-2"

# Special-token strings are passed explicitly to from_pretrained so the fast
# tokenizer uses these exact markers for BOS/EOS/UNK/PAD/MASK.
_special_tokens = dict(
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    mask_token="<mask>",
)
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_dir, **_special_tokens)
def load_model(model_name):
    """Instantiate and return the LM-head model stored under ``model_name``.

    ``model_name`` is whatever ``AutoModelWithLMHead.from_pretrained`` accepts
    (here, the Hub id in ``model_dir``).

    NOTE(review): AutoModelWithLMHead is deprecated in transformers 4.x in
    favour of AutoModelForCausalLM for decoder-only models — confirm the
    installed version before switching.
    """
    return AutoModelWithLMHead.from_pretrained(model_name)


# NOTE(review): Streamlit re-runs this whole script on every widget
# interaction, so this load (and the tokenizer load) repeats each time.
# Consider st.cache_resource (or st.cache on older Streamlit) if the installed
# version supports it.
model = load_model(model_dir)
print("loaded model completed")
def find_nth(haystack, needle, n):
    """Return the index of the n-th non-overlapping occurrence of ``needle``.

    Each search resumes just past the previous match, so overlapping
    occurrences are not counted. Returns -1 when ``haystack`` contains fewer
    than ``n`` occurrences; any ``n <= 1`` yields the first occurrence.
    """
    position = haystack.find(needle)
    remaining = n
    while remaining > 1 and position != -1:
        position = haystack.find(needle, position + len(needle))
        remaining -= 1
    return position
def infer(input_ids, max_length, temperature, top_k, top_p):
    """Sample a single continuation from the module-level ``model``.

    Args:
        input_ids: encoded prompt passed to ``model.generate`` (may be None).
        max_length: maximum total length handed to ``generate``.
        temperature: sampling temperature.
        top_k: top-k sampling cutoff.
        top_p: nucleus (top-p) sampling threshold.

    Returns:
        Whatever ``model.generate`` returns for one sampled sequence
        (presumably a (1, seq_len) tensor of token ids — callers index it
        with ``[0]``).
    """
    sampling_options = dict(
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=1,
    )
    return model.generate(input_ids=input_ids, **sampling_options)
# prompts | |
st.title("์ผํ์์ ๋ฌ์ธ KoGPT์ ๋๋ค ๐ฆ") | |
st.write("ํ ์คํธ๋ฅผ ์ ๋ ฅํ๊ณ CTRL+Enter(CMD+Enter)์ ๋๋ฅด์ธ์ ๐ค") | |
# text and sidebars | |
default_value = "๋ฐ์๋ฏผ" | |
sent = st.text_area("Text", default_value, max_chars=4, height=275) | |
max_length = st.sidebar.slider("์์ฑ ๋ฌธ์ฅ ๊ธธ์ด๋ฅผ ์ ํํด์ฃผ์ธ์!", min_value=42, max_value=64) | |
temperature = st.sidebar.slider( | |
"Temperature", value=1.0, min_value=0.0, max_value=1.0, step=0.05 | |
) | |
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0) | |
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9) | |
print("slider sidebars rendering completed") | |
# make input sentence | |
emotion_list = ["ํ๋ณต", "์ค๋ฆฝ", "๋ถ๋ ธ", "ํ์ค", "๋๋", "์ฌํ", "๊ณตํฌ"] | |
main_emotion = st.sidebar.radio("์ฃผ์ ๊ฐ์ ์ ์ ํํ์ธ์", emotion_list) | |
sub_emotion = st.sidebar.radio("๋ ๋ฒ์งธ ๊ฐ์ ์ ์ ํํ์ธ์", emotion_list) | |
print("radio sidebars rendering completed") | |
# create condition sentence | |
random_main_logit = np.random.normal(loc=3.368, scale=1.015, size=1)[0].round(1) | |
random_sub_logit = np.random.normal(loc=1.333, scale=0.790, size=1)[0].round(1) | |
condition_sentence = f"{random_main_logit}๋งํผ {main_emotion}๊ฐ์ ์ธ ๋ฌธ์ฅ์ด๋ค. {random_sub_logit}๋งํผ {sub_emotion}๊ฐ์ ์ธ ๋ฌธ์ฅ์ด๋ค. " | |
condition_plus_input = condition_sentence + sent | |
print(condition_plus_input) | |
def infer_sentence( | |
condition_plus_input=condition_plus_input, tokenizer=tokenizer, top_k=2 | |
): | |
encoded_prompt = tokenizer.encode( | |
condition_plus_input, add_special_tokens=False, return_tensors="pt" | |
) | |
if encoded_prompt.size()[-1] == 0: | |
input_ids = None | |
else: | |
input_ids = encoded_prompt | |
output_sequences = infer(input_ids, max_length, temperature, top_k, top_p) | |
print(output_sequences) | |
generated_sequence = output_sequences[0] | |
print(generated_sequence) | |
# print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===") | |
# generated_sequences = generated_sequence.tolist() | |
# Decode text | |
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) | |
print(text) | |
# Remove all text after the stop token | |
stop_token = tokenizer.pad_token | |
print(stop_token) | |
text = text[: text.find(stop_token) if stop_token else None] | |
print(text) | |
condition_index = find_nth(text, "๋ฌธ์ฅ์ด๋ค", 2) | |
text = text[condition_index + 5 :] | |
text = text.strip() | |
return text | |
def make_residual_conditional_samhaengshi(input_letter, condition_sentence):
    """Generate one acrostic (samhaengshi) line per character of ``input_letter``.

    The first prompt is ``condition_sentence`` plus the first character. Each
    later prompt is ``condition_sentence`` plus the previously generated line
    plus the next character. The previous line is then removed from the new
    output, so each list entry holds only the new line.

    Args:
        input_letter: the characters to build the poem from, one line each.
        condition_sentence: emotion-condition prefix prepended to every prompt.

    Returns:
        A list with one generated line per character of ``input_letter``.
        The list is empty when ``input_letter`` is empty (the loop-internal
        return previously made this case return None).
    """
    list_samhaengshi = []
    residual_text = ""
    last_index = len(input_letter) - 1
    for index, letter_item in enumerate(input_letter):
        if index == 0:
            # Later iterations set residual_text at the bottom of the loop.
            residual_text = letter_item

        conditional_input = f"{condition_sentence} {residual_text}"
        inferred_sentence = infer_sentence(conditional_input, tokenizer)

        if index != 0:
            # The prompt's previous line is echoed back in the output; remove it.
            print("inferred_sentence:", inferred_sentence)
            inferred_sentence = inferred_sentence.replace(
                list_samhaengshi[index - 1], ""
            ).strip()

        list_samhaengshi.append(inferred_sentence)

        if index < last_index:
            # previous sentence + next letter feeds the next prompt
            residual_text = f"{inferred_sentence} {input_letter[index + 1]}"
            print("residual_text", residual_text)

    return list_samhaengshi
# Build the acrostic from the user's text, log it to stdout, and render it.
return_text = make_residual_conditional_samhaengshi(
    input_letter=sent,
    condition_sentence=condition_sentence,
)
print(return_text)
st.write(return_text)