|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig |
|
modelname = "gpt2-large"

# Tokenizer and causal-LM weights are loaded once at import time; botsay()
# reads the module-level `tokenizer` and `model` on every request.
tokenizer = AutoTokenizer.from_pretrained(modelname)

# The config is fetched explicitly and handed to the model loader.
config = AutoConfig.from_pretrained(modelname)
model = AutoModelForCausalLM.from_pretrained(
    modelname,
    config=config,
)
|
|
|
|
|
# Vocabulary ids used by botsay (GPT-2 tokenizer).
_EOS_TOKEN_ID = 50256  # "<|endoftext|>"
_NEWLINE_TOKEN_ID = 198  # newline
_SPACE_THAT_TOKEN_ID = 326  # presumably " that" (leading space) -- confirm
_THAT_TOKEN_ID = 5562  # presumably "that" (no leading space) -- confirm
_NO_FORCE = -1  # sentinel id that never appears among top-k indices


def _forced_that_id(prev_token, gen_tokens):
    """Return the "that" token id to force next, or ``_NO_FORCE``.

    The variant is chosen from the previously generated token: after a token
    ending in a space the space-less "that" is used, otherwise " that".
    Forcing is skipped right after " that" itself and near "thatGPT".
    """
    if prev_token == " that":
        return _NO_FORCE
    # NOTE(review): case-sensitive, so the bot's name "ThatGPT" does not
    # match here -- confirm whether that is intended.
    if "thatGPT" in gen_tokens[-12:]:
        return _NO_FORCE
    if prev_token.endswith(" "):
        return _THAT_TOKEN_ID
    return _SPACE_THAT_TOKEN_ID


def _clean_reply(text):
    """Remove "<br>", "AI:" and non-breaking spaces from a generated reply."""
    return text.replace("<br>", "").replace("AI:", "").replace("\xa0", "")


def botsay(user_input):
    """Generate ThatGPT's reply to a serialized conversation.

    Decoding is greedy (argmax) over the module-level ``model``/``tokenizer``,
    except that whenever a "that" token is among the top-7 candidates it is
    chosen instead, so the reply keeps saying "that". If the last few
    characters of the reply recur too often, the newest token is swapped for
    a random top-5 candidate. Generation stops at an end-of-text or newline
    token once the reply contains "that"; if it does not yet, newlines and
    periods are stripped, " that" is appended and generation continues.

    Args:
        user_input: Conversation text appended to the system prompt, made of
            "Human:" / "AI:" lines (see ``bot``).

    Returns:
        The cleaned reply text, or a warning message once the prompt plus the
        reply so far exceeds ``limit`` tokens.
    """
    prompt = "This is a conversation between Human and AI bot. AI's name is ThatGPT."
    top_k = 7  # a forced "that" must be among this many top candidates
    limit = 128  # maximum prompt length in tokens
    repeat_window = 5  # length of the tail checked for repetition
    max_repeats = 2  # tail may occur at most this many times

    gen_tokens = ""
    new_token = ""
    step = 0
    while True:
        step += 1
        input_ids = tokenizer(
            prompt + user_input + "\nAI:" + gen_tokens, return_tensors="pt"
        ).input_ids
        # input_ids has shape (1, seq_len); len() would always be 1 and the
        # limit would never trigger, so measure the sequence dimension.
        if input_ids.shape[-1] > limit:
            return "⚠️sorry length limit. please reload the browser."

        # Inference only: don't build an autograd graph on every step.
        with torch.no_grad():
            logits = model(input_ids=input_ids).logits[0, -1, :]
        candidates = torch.topk(logits, k=top_k).indices

        that_id = _forced_that_id(new_token, gen_tokens)
        if that_id in candidates:
            new_token_id = that_id
        else:
            new_token_id = int(torch.argmax(logits))
        new_token = tokenizer.decode(new_token_id)

        prev_tokens = gen_tokens
        gen_tokens += new_token
        if step > 10 and gen_tokens.count(gen_tokens[-repeat_window:]) > max_repeats:
            # The tail repeats too often: replace the new token with a random
            # top-5 candidate, keeping new_token_id in sync with new_token so
            # the stop check below looks at the token actually appended.
            new_token_id = int(candidates[torch.randint(5, (1,)).item()])
            new_token = tokenizer.decode(new_token_id)
            gen_tokens = prev_tokens + new_token

        if new_token_id in (_EOS_TOKEN_ID, _NEWLINE_TOKEN_ID) or new_token == "<|endoftext|>":
            if "that" in gen_tokens:
                break
            gen_tokens = gen_tokens.replace("\n", "").replace(".", "") + " that"

    return _clean_reply(gen_tokens)
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
def add_text(history, text):
    """Append the user's message to the chat as a pending (unanswered) turn.

    Returns a new history list (the input list is left untouched) together
    with an empty string that clears the textbox.
    """
    pending_turn = (text, None)
    updated_history = history + [pending_turn]
    cleared_textbox = ""
    return updated_history, cleared_textbox
|
|
|
|
|
def bot(history):
    """Fill in the AI reply for the last (pending) turn of the chat.

    ``history`` is the Chatbot value: a sequence of ``(human, ai)`` pairs whose
    pending pair has ``ai`` set to ``None`` (as appended by ``add_text``).
    Turns up to and including the first pending one are serialized as
    alternating "Human:" / "AI:" lines and passed to ``botsay``.

    Returns:
        ``history`` with the last turn's reply filled in (empty history is
        returned unchanged).
    """
    if not history:
        return history  # nothing to answer

    serial_history = ""
    for human_text, ai_text in history:
        serial_history += "\nHuman:" + human_text
        if ai_text is None:
            break  # pending turn: its reply is what botsay generates
        serial_history += "\nAI:" + ai_text.replace("<br>", "")

    response = botsay(serial_history)
    # Rebuild the last pair instead of item-assigning into it, so tuple
    # entries (as created by add_text) work as well as lists.
    history[-1] = [history[-1][0], response]
    return history
|
|
|
with gr.Blocks() as demo:
    gr.Markdown('# ThatGPT - AI always replies with "that" -')
    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)

    with gr.Row(), gr.Column(scale=0.85):
        txt = gr.Textbox(
            show_label=False,
            placeholder='AI always replies with "that". It may take more than ten seconds.',
        ).style(container=False)

    # Pressing Enter first appends the message (and clears the box), then
    # lets the bot fill in its reply.
    submitted = txt.submit(add_text, [chatbot, txt], [chatbot, txt])
    submitted.then(bot, chatbot, chatbot)

demo.launch()
|
|