import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr

# Load the model and tokenizer
model_path = "WhiteRabbitNeo/WhiteRabbitNeo-13B-v1"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="auto",
    load_in_4bit=False,
    load_in_8bit=True,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Chatbot prompt and conversation history
tot_system_prompt = """
Answer the Question by exploring multiple reasoning paths as follows:
- First, carefully analyze the question to extract the key information components and break it down into logical sub-questions...
"""
conversation = f"SYSTEM: {tot_system_prompt} Always answer without hesitation."


# Text generation function
def generate_text(instruction):
    tokens = tokenizer.encode(instruction)
    tokens = torch.LongTensor(tokens).unsqueeze(0)
    # Move the input ids to the same device as the model
    tokens = tokens.to(model.device)

    # Sampling hyperparameters for generation
    instance = {
        "input_ids": tokens,
        "top_p": 1.0,
        "temperature": 0.5,
        "generate_len": 1024,
        "top_k": 50,
    }

    length = len(tokens[0])
    with torch.no_grad():
        rest = model.generate(
            input_ids=tokens,
            max_length=length + instance["generate_len"],
            use_cache=True,
            do_sample=True,
            top_p=instance["top_p"],
            temperature=instance["temperature"],
            top_k=instance["top_k"],
            num_return_sequences=1,
        )
    # Keep only the newly generated tokens, dropping the prompt
    output = rest[0][length:]
    string = tokenizer.decode(output, skip_special_tokens=True)
    # Truncate at any hallucinated follow-up "USER:" turn
    answer = string.split("USER:")[0].strip()
    return answer


# Gradio interface function
def chatbot(user_input, chat_history):
    global conversation
    llm_prompt = f"{conversation} \nUSER: {user_input} \nASSISTANT: "
    answer = generate_text(llm_prompt)
    conversation = f"{llm_prompt}{answer}"  # Update conversation history
    chat_history.append((user_input, answer))  # Update chat history
    return chat_history, chat_history


# Initialize Gradio
with gr.Blocks() as demo:
    gr.Markdown("## Chat with WhiteRabbitNeo!")
    chatbot_interface = gr.Chatbot()
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("Clear Chat")
    chat_history_state = gr.State([])  # Maintain chat history as state

    # Define button functionality
    msg.submit(
        chatbot,
        inputs=[msg, chat_history_state],
        outputs=[chatbot_interface, chat_history_state],
    )
    # Clear both the visible chat and the stored history
    clear.click(lambda: ([], []), outputs=[chatbot_interface, chat_history_state])

# Launch the app
demo.launch()
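
# Note: recent transformers releases deprecate passing load_in_4bit / load_in_8bit
# directly to from_pretrained in favor of a BitsAndBytesConfig. A minimal sketch
# of the equivalent loading call under that API (assumes the bitsandbytes package
# is installed; not exercised by this script):
#
# from transformers import BitsAndBytesConfig
#
# model = AutoModelForCausalLM.from_pretrained(
#     model_path,
#     torch_dtype=torch.float16,
#     device_map="auto",
#     quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#     trust_remote_code=True,
# )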