# -*- coding: utf-8 -*-
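# Streamlit demo: generate a Korean samhaengshi (acrostic poem) with a conditional
# KoGPT model, producing one poem line per input character, conditioned on two
# user-selected emotions.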
import numpy as np
import streamlit as st
from transformers import AutoModelWithLMHead, PreTrainedTokenizerFast
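# NOTE: AutoModelWithLMHead is deprecated in recent transformers releases;
# AutoModelForCausalLM is the usual replacement for causal LM checkpoints such as this one.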
model_dir = "snoop2head/kogpt-conditional-2"
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    model_dir,
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    mask_token="<mask>",
)
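# Cache the loaded model so it is downloaded and initialized only once per session.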
@st.cache
def load_model(model_name):
    model = AutoModelWithLMHead.from_pretrained(model_name)
    return model
model = load_model(model_dir)
print("loaded model completed")
def find_nth(haystack, needle, n):
    start = haystack.find(needle)
    while start >= 0 and n > 1:
        start = haystack.find(needle, start + len(needle))
        n -= 1
    return start
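# Sample a single continuation from the model with temperature / top-k / top-p decoding.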
def infer(input_ids, max_length, temperature, top_k, top_p):
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=1,
    )
    return output_sequences
# prompts
st.title("์‚ผํ–‰์‹œ์˜ ๋‹ฌ์ธ KoGPT์ž…๋‹ˆ๋‹ค ๐Ÿฆ„")
st.write("ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜๊ณ  CTRL+Enter(CMD+Enter)์„ ๋ˆ„๋ฅด์„ธ์š” ๐Ÿค—")
# text and sidebars
default_value = "๋ฐ•์ˆ˜๋ฏผ"
sent = st.text_area("Text", default_value, max_chars=4, height=275)
max_length = st.sidebar.slider("์ƒ์„ฑ ๋ฌธ์žฅ ๊ธธ์ด๋ฅผ ์„ ํƒํ•ด์ฃผ์„ธ์š”!", min_value=42, max_value=64)
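# The text area above caps input at 4 characters (one poem line is generated per character);
# the slider label reads "Choose the length of the generated sentence!".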
temperature = st.sidebar.slider(
"Temperature", value=1.0, min_value=0.0, max_value=1.0, step=0.05
)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.9)
print("slider sidebars rendering completed")
# make input sentence
emotion_list = ["ํ–‰๋ณต", "์ค‘๋ฆฝ", "๋ถ„๋…ธ", "ํ˜์˜ค", "๋†€๋žŒ", "์Šฌํ””", "๊ณตํฌ"]
main_emotion = st.sidebar.radio("์ฃผ์š” ๊ฐ์ •์„ ์„ ํƒํ•˜์„ธ์š”", emotion_list)
sub_emotion = st.sidebar.radio("๋‘ ๋ฒˆ์งธ ๊ฐ์ •์„ ์„ ํƒํ•˜์„ธ์š”", emotion_list)
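# emotion_list (translated): happiness, neutral, anger, disgust, surprise, sadness, fear.
# The radio labels read "Choose the main emotion" and "Choose the second emotion".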
print("radio sidebars rendering completed")
# create condition sentence
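# Sample emotion intensities for the condition prompt; the loc/scale values are
# presumably statistics of the emotion logits in the checkpoint's training data.
# The resulting condition reads (translated): "A sentence that is {main_emotion} to
# degree {random_main_logit}. A sentence that is {sub_emotion} to degree {random_sub_logit}."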
random_main_logit = np.random.normal(loc=3.368, scale=1.015, size=1)[0].round(1)
random_sub_logit = np.random.normal(loc=1.333, scale=0.790, size=1)[0].round(1)
condition_sentence = f"{random_main_logit}๋งŒํผ {main_emotion}๊ฐ์ •์ธ ๋ฌธ์žฅ์ด๋‹ค. {random_sub_logit}๋งŒํผ {sub_emotion}๊ฐ์ •์ธ ๋ฌธ์žฅ์ด๋‹ค. "
condition_plus_input = condition_sentence + sent
print(condition_plus_input)
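# Encode the condition + seed text, generate a continuation, cut everything after the
# pad token, and strip the condition prefix so only the generated sentence is returned.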
def infer_sentence(
    condition_plus_input=condition_plus_input, tokenizer=tokenizer, top_k=2
):
    encoded_prompt = tokenizer.encode(
        condition_plus_input, add_special_tokens=False, return_tensors="pt"
    )
    if encoded_prompt.size()[-1] == 0:
        input_ids = None
    else:
        input_ids = encoded_prompt

    output_sequences = infer(input_ids, max_length, temperature, top_k, top_p)
    print(output_sequences)
    generated_sequence = output_sequences[0]
    print(generated_sequence)
    # print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
    # generated_sequences = generated_sequence.tolist()

    # Decode text
    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
    print(text)

    # Remove all text after the stop token
    stop_token = tokenizer.pad_token
    print(stop_token)
    text = text[: text.find(stop_token) if stop_token else None]
    print(text)

    # Strip the condition prefix: skip past the second "๋ฌธ์žฅ์ด๋‹ค" ("is a sentence")
    # plus the period that follows it
    condition_index = find_nth(text, "๋ฌธ์žฅ์ด๋‹ค", 2)
    text = text[condition_index + 5 :]
    text = text.strip()
    return text
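# Build the acrostic line by line: each input character starts a new line, and the
# previously generated line is fed back as context (a "residual") for the next one,
# with the overlap removed so each line appears only once in the result.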
def make_residual_conditional_samhaengshi(input_letter, condition_sentence):
    # turn the input letter string into a list of generated lines
    list_samhaengshi = []
    # initializing text and index for iteration purpose
    index = 0
    # iterating over the input letter string
    for index, letter_item in enumerate(input_letter):
        # initializing the input_letter
        if index == 0:
            residual_text = letter_item
            # print('residual_text:', residual_text)

        # infer and add to the output
        conditional_input = f"{condition_sentence} {residual_text}"
        inferred_sentence = infer_sentence(conditional_input, tokenizer)
        if index != 0:
            # remove previous sentence from the output
            print("inferred_sentence:", inferred_sentence)
            inferred_sentence = inferred_sentence.replace(
                list_samhaengshi[index - 1], ""
            ).strip()
        else:
            pass
        list_samhaengshi.append(inferred_sentence)

        # until the end of the input_letter, give the previous residual_text to the next iteration
        if index < len(input_letter) - 1:
            residual_sentence = list_samhaengshi[index]
            next_letter = input_letter[index + 1]
            residual_text = (
                f"{residual_sentence} {next_letter}"  # previous sentence + next letter
            )
            print("residual_text", residual_text)
        elif index == len(input_letter) - 1:  # end of the input_letter
            # Concatenate strings in the list without intersection
            return list_samhaengshi
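# Generate the acrostic for the user's input and render it in the app.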
return_text = make_residual_conditional_samhaengshi(
    input_letter=sent, condition_sentence=condition_sentence
)
print(return_text)
st.write(return_text)
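# To try this locally (assuming the file is saved as app.py): streamlit run app.py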