|
import gradio as gr |
|
import random |
|
import time, os |
|
import copy |
|
import re |
|
|
|
import torch |
|
from rich.console import Console |
|
from rich.table import Table |
|
from datetime import datetime |
|
|
|
from threading import Thread |
|
from typing import Optional |
|
from transformers import TextIteratorStreamer |
|
|
|
from utils.special_tok_llama2 import ( |
|
B_CODE, |
|
E_CODE, |
|
B_RESULT, |
|
E_RESULT, |
|
B_INST, |
|
E_INST, |
|
B_SYS, |
|
E_SYS, |
|
DEFAULT_PAD_TOKEN, |
|
DEFAULT_BOS_TOKEN, |
|
DEFAULT_EOS_TOKEN, |
|
DEFAULT_UNK_TOKEN, |
|
IGNORE_INDEX, |
|
) |
|
|
|
from finetuning.conversation_template import ( |
|
json_to_code_result_tok_temp, |
|
msg_to_code_result_tok_temp, |
|
) |
|
|
|
import warnings |
|
|
|
warnings.filterwarnings("ignore", category=UserWarning, module="transformers") |
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" |
|
|
|
|
|
from code_interpreter.LlamaCodeInterpreter import LlamaCodeInterpreter |
|
|
|
|
|
class StreamingLlamaCodeInterpreter(LlamaCodeInterpreter): |
|
streamer: Optional[TextIteratorStreamer] = None |
|
|
|
|
|
@torch.inference_mode() |
|
def generate( |
|
self, |
|
prompt: str = "[INST]\n###User : hi\n###Assistant :", |
|
max_new_tokens=512, |
|
do_sample: bool = True, |
|
use_cache: bool = True, |
|
top_p: float = 0.95, |
|
temperature: float = 0.1, |
|
top_k: int = 50, |
|
repetition_penalty: float = 1.0, |
|
) -> str: |
|
|
|
|
|
self.streamer = TextIteratorStreamer( |
|
self.tokenizer, skip_prompt=True, Timeout=5 |
|
) |
|
|
|
input_prompt = copy.deepcopy(prompt) |
|
inputs = self.tokenizer([prompt], return_tensors="pt") |
|
input_tokens_shape = inputs["input_ids"].shape[-1] |
|
|
|
eos_token_id = self.tokenizer.convert_tokens_to_ids(DEFAULT_EOS_TOKEN) |
|
e_code_token_id = self.tokenizer.convert_tokens_to_ids(E_CODE) |
|
|
|
kwargs = dict( |
|
**inputs, |
|
max_new_tokens=max_new_tokens, |
|
do_sample=do_sample, |
|
top_p=top_p, |
|
temperature=temperature, |
|
use_cache=use_cache, |
|
top_k=top_k, |
|
repetition_penalty=repetition_penalty, |
|
eos_token_id=[ |
|
eos_token_id, |
|
e_code_token_id, |
|
], |
|
streamer=self.streamer, |
|
) |
|
|
|
thread = Thread(target=self.model.generate, kwargs=kwargs) |
|
thread.start() |
|
|
|
return "" |
|
|
|
|
|
def change_markdown_image(text: str): |
|
modified_text = re.sub(r"!\[(.*?)\]\(\'(.*?)\'\)", r"![\1](/file=\2)", text) |
|
return modified_text |
|
|
|
|
|
def gradio_launch(model_path: str, load_in_4bit: bool = True, MAX_TRY: int = 5): |
|
with gr.Blocks(theme=gr.themes.Monochrome()) as demo: |
|
chatbot = gr.Chatbot(height=820, avatar_images="./assets/logo2.png") |
|
msg = gr.Textbox() |
|
clear = gr.Button("Clear") |
|
|
|
interpreter = StreamingLlamaCodeInterpreter( |
|
model_path=model_path, load_in_4bit=load_in_4bit |
|
) |
|
|
|
def bot(history): |
|
user_message = history[-1][0] |
|
|
|
interpreter.dialog.append({"role": "user", "content": user_message}) |
|
|
|
print(f"###User : [bold]{user_message}[bold]") |
|
|
|
|
|
|
|
HAS_CODE = False |
|
INST_END_TOK_FLAG = False |
|
full_generated_text = "" |
|
prompt = interpreter.dialog_to_prompt(dialog=interpreter.dialog) |
|
start_prompt = copy.deepcopy(prompt) |
|
prompt = f"{prompt} {E_INST}" |
|
|
|
_ = interpreter.generate(prompt) |
|
history[-1][1] = "" |
|
generated_text = "" |
|
for character in interpreter.streamer: |
|
history[-1][1] += character |
|
generated_text += character |
|
yield history |
|
|
|
full_generated_text += generated_text |
|
HAS_CODE, generated_code_block = interpreter.extract_code_blocks( |
|
generated_text |
|
) |
|
|
|
attempt = 1 |
|
while HAS_CODE: |
|
if attempt > MAX_TRY: |
|
break |
|
|
|
|
|
|
|
history[-1][1] = ( |
|
history[-1][1] |
|
.replace(f"{B_CODE}", "\n```python\n") |
|
.replace(f"{E_CODE}", "\n```\n") |
|
) |
|
history[-1][1] = change_markdown_image(history[-1][1]) |
|
yield history |
|
|
|
|
|
generated_code_block = generated_code_block.replace( |
|
"<unk>_", "" |
|
).replace("<unk>", "") |
|
|
|
( |
|
code_block_output, |
|
error_flag, |
|
) = interpreter.execute_code_and_return_output( |
|
f"{generated_code_block}" |
|
) |
|
code_block_output = interpreter.clean_code_output(code_block_output) |
|
generated_text = ( |
|
f"{generated_text}\n{B_RESULT}\n{code_block_output}\n{E_RESULT}\n" |
|
) |
|
full_generated_text += ( |
|
f"\n{B_RESULT}\n{code_block_output}\n{E_RESULT}\n" |
|
) |
|
|
|
|
|
history[-1][1] += f"\n```RESULT\n{code_block_output}\n```\n" |
|
history[-1][1] = change_markdown_image(history[-1][1]) |
|
yield history |
|
|
|
prompt = f"{prompt} {generated_text}" |
|
|
|
_ = interpreter.generate(prompt) |
|
for character in interpreter.streamer: |
|
history[-1][1] += character |
|
generated_text += character |
|
history[-1][1] = change_markdown_image(history[-1][1]) |
|
yield history |
|
|
|
HAS_CODE, generated_code_block = interpreter.extract_code_blocks( |
|
generated_text |
|
) |
|
|
|
if generated_text.endswith("</s>"): |
|
break |
|
|
|
attempt += 1 |
|
|
|
interpreter.dialog.append( |
|
{ |
|
"role": "assistant", |
|
"content": generated_text.replace("<unk>_", "") |
|
.replace("<unk>", "") |
|
.replace("</s>", ""), |
|
} |
|
) |
|
|
|
print("----------\n" * 2) |
|
print(interpreter.dialog) |
|
print("----------\n" * 2) |
|
|
|
return history[-1][1] |
|
|
|
def user(user_message, history): |
|
return "", history + [[user_message, None]] |
|
|
|
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then( |
|
bot, chatbot, chatbot |
|
) |
|
clear.click(lambda: None, None, chatbot, queue=False) |
|
|
|
demo.queue() |
|
demo.launch() |
|
|
|
|
|
if __name__ == "__main__": |
|
import argparse |
|
|
|
parser = argparse.ArgumentParser(description="Process path for LLAMA2_FINETUNEED.") |
|
parser.add_argument( |
|
"--path", |
|
type=str, |
|
required=True, |
|
help="Path to the finetuned LLAMA2 model.", |
|
default="./output/llama-2-7b-codellama-ci", |
|
) |
|
args = parser.parse_args() |
|
|
|
gradio_launch(model_path=args.path, load_in_4bit=True) |
|
|