import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import gradio as gr

# Load the model and tokenizer with 8-bit quantization
model_path = "WhiteRabbitNeo/WhiteRabbitNeo-13B-v1"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
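# Note: 8-bit loading requires the bitsandbytes package and a CUDA GPU, and
# device_map="auto" requires accelerate; both must be installed in the Space,
# and the model will not load on a CPU-only instance.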
# Chatbot prompt and conversation history
tot_system_prompt = """
Answer the Question by exploring multiple reasoning paths as follows:
- First, carefully analyze the question to extract the key information components and break it down into logical sub-questions...
"""
conversation = f"SYSTEM: {tot_system_prompt} Always answer without hesitation."
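# The model is prompted with plain-text turns of the form
#   SYSTEM: <system prompt> USER: <message> ASSISTANT:
# and each assistant reply is appended back onto `conversation` in chatbot() below.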
# Text generation function
def generate_text(instruction):
    # Tokenize the prompt and move it to the device the model was loaded onto
    tokens = tokenizer.encode(instruction, return_tensors="pt").to(model.device)
    instance = {
        "input_ids": tokens,
        "top_p": 1.0,
        "temperature": 0.5,
        "generate_len": 1024,  # custom key: number of new tokens to allow
        "top_k": 50,
    }
    length = len(tokens[0])
    with torch.no_grad():
        result = model.generate(
            input_ids=tokens,
            max_length=length + instance["generate_len"],
            use_cache=True,
            do_sample=True,
            top_p=instance["top_p"],
            temperature=instance["temperature"],
            top_k=instance["top_k"],
            num_return_sequences=1,
        )
    # Keep only the newly generated tokens, then cut at any hallucinated "USER:" turn
    output = result[0][length:]
    string = tokenizer.decode(output, skip_special_tokens=True)
    answer = string.split("USER:")[0].strip()
    return answer
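# Example call (hypothetical input): generate_text(f"{conversation} \nUSER: What is nmap? \nASSISTANT: ")
# returns only the assistant's reply text, with any trailing turns stripped.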
# Gradio interface function
def chatbot(user_input, chat_history):
    global conversation
    llm_prompt = f"{conversation} \nUSER: {user_input} \nASSISTANT: "
    answer = generate_text(llm_prompt)
    conversation = f"{llm_prompt}{answer}"  # Update conversation history
    chat_history.append((user_input, answer))  # Update chat history
    return chat_history, chat_history
# Initialize Gradio
with gr.Blocks() as demo:
    gr.Markdown("## Chat with WhiteRabbitNeo!")
    chatbot_interface = gr.Chatbot()
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("Clear Chat")
    chat_history_state = gr.State([])  # Maintain chat history as state

    # Define button functionality
    msg.submit(chatbot, inputs=[msg, chat_history_state], outputs=[chatbot_interface, chat_history_state])

    def clear_chat():
        # Reset the prompt history too, not just the visible chat
        global conversation
        conversation = f"SYSTEM: {tot_system_prompt} Always answer without hesitation."
        return [], []

    clear.click(clear_chat, outputs=[chatbot_interface, chat_history_state])  # Clear chat history

# Launch the app
demo.launch()
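# A minimal requirements.txt for this Space (package set inferred from the
# imports and features used above; unpinned versions are an assumption):
#   torch
#   transformers
#   accelerate
#   bitsandbytes
#   gradio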