import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
import time
import numpy as np
from torch.nn import functional as F
import os
from threading import Thread

# Hugging Face Hub id of the chat model served by this demo.
MODEL_ID = "stabilityai/stablelm-2-1_6b-zephyr"

print("Starting to load the model to memory")
# Run on GPU when one is present. float16 is kept for GPU only: on CPU many
# half-precision kernels are slow or unimplemented (depending on the PyTorch
# version), so fall back to float32 there.
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
m = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if _device.type == "cuda" else torch.float32,
    trust_remote_code=True,
).to(_device)
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Pass the device explicitly so the pipeline puts its inputs where the model lives.
generator = pipeline("text-generation", model=m, tokenizer=tok, device=_device)
print("Successfully loaded the model into memory")

start_message = ""

def user(message, history):
    """Register a new user turn in the chat history.

    Args:
        message: Text the user just submitted.
        history: Current chatbot history, a list of ``[user, reply]`` pairs.

    Returns:
        A ``(textbox_value, new_history)`` tuple. ``textbox_value`` is ``""``,
        which clears the input box. ``new_history`` is a fresh list with
        ``[message, ""]`` appended; the empty reply slot is filled in later
        by ``chat``. The caller's list is not mutated.
    """
    pending_turn = [message, ""]
    cleared_textbox = ""
    return cleared_textbox, history + [pending_turn]

def chat(history):
    """Stream the model's reply to the last user turn in ``history``.

    Args:
        history: Gradio chatbot history, a list of ``[user_message, reply]``
            pairs. The last pair is the turn being answered; its reply slot
            is overwritten in place as text streams in.

    Yields:
        ``history`` after each newly decoded chunk, so Gradio can re-render
        the chatbot incrementally.

    Returns:
        The full generated reply. This is the generator's return value and
        is ignored by Gradio.
    """
    messages = []
    for user_msg, reply in history:
        messages.append({"role": "user", "content": user_msg})
        # Skip missing/empty replies, notably the "" placeholder `user` adds
        # for the pending turn, so the prompt ends on the user's message.
        if reply:
            messages.append({"role": "assistant", "content": reply})
    # add_generation_prompt appends the assistant header so the model starts
    # a reply rather than continuing the user's text.
    prompt = tok.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True)
    # Tokenize the prompt and place it on the same device as the model.
    model_inputs = tok([prompt], return_tensors="pt").to(m.device)
    streamer = TextIteratorStreamer(
        tok, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    # generate() blocks, so run it in a worker thread and consume decoded
    # text from the streamer here.
    t = Thread(target=m.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1][1] = partial_text
        yield history
    # The streamer is exhausted once generation ends; reap the worker thread.
    t.join()
    return partial_text

with gr.Blocks() as demo:
    gr.Markdown("## Stable LM 1.6b Zephyr")
    # Sizing/containment are constructor kwargs: `.style()` was removed in
    # Gradio 4.
    chatbot = gr.Chatbot(height=500)
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Chat Message Box",
                show_label=False,
                container=False,
            )
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")

    # Enter in the textbox and the Submit button run the same two-step chain:
    # `user` appends the turn without queueing (instant UI update), then
    # `chat` streams the reply through the queue.
    submit_event = msg.submit(
        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(fn=chat, inputs=[chatbot], outputs=[chatbot], queue=True)
    submit_click_event = submit.click(
        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(fn=chat, inputs=[chatbot], outputs=[chatbot], queue=True)
    # Stop cancels whichever streaming chain is in flight.
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    # Clear resets the chatbot to an empty conversation.
    clear.click(lambda: None, None, [chatbot], queue=False)

try:
    # Gradio 4.x spelling of the worker-concurrency setting.
    demo.queue(max_size=32, default_concurrency_limit=2)
except TypeError:
    # Gradio 3.x has no `default_concurrency_limit`; it uses `concurrency_count`.
    demo.queue(max_size=32, concurrency_count=2)
demo.launch()