import gradio as gr
from typing import Iterator

import torch
import transformers
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig

# Load the LLaMA-7B tokenizer and model. device_map="cpu" keeps everything on
# the CPU, so the demo runs without a GPU but generation will be slow.
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    device_map="cpu",
)


def evaluate(question):
    # Single-turn helper: wrap the question in the chat template and generate a reply.
    prompt = f"The conversation between human and AI assistant.\n[|Human|] {question}.\n[|AI|] "
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)
    generation_output = model.generate(
        input_ids=input_ids,
        generation_config=GenerationConfig(
            temperature=1.0,
            top_p=0.95,
            num_beams=4,
        ),
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=512,
    )
    # Keep only the text generated after the [|AI|] marker.
    output = tokenizer.decode(generation_output.sequences[0]).split("[|AI|]")[1]
    return output


def generate_prompt_with_history(text: str, history: list, tokenizer, max_length=2048):
    # Build the prompt from the newest turns backwards, keeping only as much
    # history as fits within max_length tokens.
    history = ["\n[|Human|]{}\n[|AI|]{}".format(x[0], x[1]) for x in history]
    history.append("\n[|Human|]{}\n[|AI|]".format(text))
    history_text = ""
    flag = False
    for x in history[::-1]:
        if tokenizer(history_text + x, return_tensors="pt")["input_ids"].size(-1) <= max_length:
            history_text = x + history_text
            flag = True
        else:
            break
    if flag:
        return history_text, tokenizer(history_text, return_tensors="pt")
    else:
        return False
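
# For example (illustrative only): history = [("Hi", "Hello!")] and text = "How are you?"
# produce the prompt
#   "\n[|Human|]Hi\n[|AI|]Hello!\n[|Human|]How are you?\n[|AI|]"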


def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
    # True if s ends with a stop word, or with a partial prefix of one.
    # The prefix check matters while streaming, when a stop word may still be
    # arriving token by token.
    for stop_word in stop_words:
        if s.endswith(stop_word):
            return True
        for i in range(1, len(stop_word)):
            if s.endswith(stop_word[:i]):
                return True
    return False
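
# Quick illustrative checks (not part of the original script):
#   is_stop_word_or_prefix("... the capital [|Hu", ["[|Human|]", "[|AI|]"])  # True (partial stop word)
#   is_stop_word_or_prefix("... the capital.", ["[|Human|]", "[|AI|]"])      # False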


def greedy_search(input_ids: torch.Tensor,
                  model: torch.nn.Module,
                  tokenizer: transformers.PreTrainedTokenizer,
                  stop_words: list,
                  max_length: int,
                  temperature: float = 1.0,
                  top_p: float = 1.0,
                  top_k: int = 25) -> Iterator[str]:
    # Despite the name, this streams tokens using temperature + top-p (nucleus)
    # sampling; top_k is accepted for API symmetry but is not used.
    generated_tokens = []
    past_key_values = None
    for i in range(max_length):
        with torch.no_grad():
            if past_key_values is None:
                outputs = model(input_ids)
            else:
                # Reuse the KV cache: only the newest token needs a forward pass.
                outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
            logits = outputs.logits[:, -1, :]
            past_key_values = outputs.past_key_values

            # Temperature scaling.
            logits /= temperature

            probs = torch.softmax(logits, dim=-1)

            # Top-p filtering: drop tokens outside the smallest set whose
            # cumulative probability covers top_p, then renormalise and sample.
            probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
            probs_sum = torch.cumsum(probs_sort, dim=-1)
            mask = probs_sum - probs_sort > top_p
            probs_sort[mask] = 0.0
            probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
            next_token = torch.multinomial(probs_sort, num_samples=1)
            next_token = torch.gather(probs_idx, -1, next_token)

            input_ids = torch.cat((input_ids, next_token), dim=-1)

            generated_tokens.append(next_token[0].item())
            text = tokenizer.decode(generated_tokens)

            # Stream the partial text; stop once any stop word appears.
            yield text
            if any(x in text for x in stop_words):
                return
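
# Toy illustration of the top-p mask above (illustrative only): with sorted
# probabilities [0.5, 0.3, 0.15, 0.05] and top_p = 0.8, the cumulative sums are
# [0.5, 0.8, 0.95, 1.0]; the condition (cumsum - prob > 0.8) masks only the
# 0.05 tail, so the first three tokens remain candidates after renormalisation.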


@torch.no_grad()
def predict(text: str,
            history: list = None,
            top_p: float = 0.95,
            temperature: float = 1.0,
            max_length_tokens: int = 512,
            max_context_length_tokens: int = 2048):
    if text == "":
        return ""
    history = history or []

    # Build the prompt from as much recent history as fits in the context window.
    result = generate_prompt_with_history(text, history, tokenizer,
                                          max_length=max_context_length_tokens)
    if result is False:
        return ""
    prompt, inputs = result
    begin_length = len(prompt)

    input_ids = inputs["input_ids"].to(model.device)
    output = []

    # Stream the sampled text, trimming it as soon as a stop marker shows up.
    for x in greedy_search(input_ids, model, tokenizer,
                           stop_words=["[|Human|]", "[|AI|]"],
                           max_length=max_length_tokens,
                           temperature=temperature,
                           top_p=top_p):
        if is_stop_word_or_prefix(x, ["[|Human|]", "[|AI|]"]) is False:
            if "[|Human|]" in x:
                x = x[:x.index("[|Human|]")].strip()
            elif "[| Human |]" in x:
                x = x[:x.index("[| Human |]")].strip()
            if "[|AI|]" in x:
                x = x[:x.index("[|AI|]")].strip()
            x = x.strip(" ")
            output.append(x)
    return output[-1] if output else ""
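
# A minimal sketch of calling predict() directly (hypothetical values, assuming
# the model above has finished loading):
#   reply = predict("What is nucleus sampling?",
#                   history=[("Hi!", "Hello, how can I help?")])
#   print(reply)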


# Expose predict() through a simple text-in / text-out Gradio interface.
iface = gr.Interface(fn=predict,
                     inputs="text",
                     outputs=["text"],
                     title="Learn with ChadGPT",
                     description="Ciao!!!")

iface.launch(inline=False)