"""Gradio chat demo for THUDM/LongWriter-llama3.1-8b.

The tokenizer and model are loaded once at import time; ``stream_chat``
streams each reply back to the UI as it is generated.
"""

import os
import time
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
import gradio as gr
from threading import Event, Thread

MODEL = "THUDM/LongWriter-llama3.1-8b"

# Rendered with gr.HTML below.
TITLE = "<h1><center>LongWriter-llama3.1-8b</center></h1>"

PLACEHOLDER = """

Hi! I'm LongWriter, capable of generating 10,000+ words. How can I assist you today?

"""

CSS = """
.duplicate-button {
  margin: auto !important;
  color: white !important;
  background: black !important;
  border-radius: 100vh !important;
}
h3 {
  text-align: center;
}
"""

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
model = model.eval()


class _StopOnEvent(StoppingCriteria):
    """Stopping criterion that ends ``model.generate`` once *event* is set."""

    def __init__(self, event: Event):
        self.event = event

    def __call__(self, input_ids, scores, **kwargs):
        # Recent transformers versions expect one bool per sequence; a
        # (batch_size,) bool tensor also works where a plain bool was expected,
        # since only a single prompt is ever batched here.
        return torch.full(
            (input_ids.shape[0],),
            self.event.is_set(),
            dtype=torch.bool,
            device=input_ids.device,
        )


def _build_prompt(message: str, history: list, system_prompt: str) -> str:
    """Render the system prompt, prior (user, bot) turns and *message* as one prompt string."""
    # Llama-2-style <<SYS>> system block followed by [INST]...[/INST] turns;
    # verify against the LongWriter model card if the prompt format changes.
    parts = [f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
    for prompt, answer in history:
        # A missing bot reply (None) must not be rendered as the text "None".
        parts.append(f"[INST]{prompt}[/INST]{answer or ''}")
    parts.append(f"[INST]{message}[/INST]")
    return "".join(parts)


@spaces.GPU()
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.5,
    max_new_tokens: int = 32768,
    top_p: float = 1.0,
    top_k: int = 50,
):
    """Generate a reply to *message* and stream it as it is produced.

    Args:
        message: The new user message.
        history: Prior turns as ``(user, bot)`` pairs, as passed by
            ``gr.ChatInterface``.
        system_prompt: Text placed in the system block of the prompt.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        max_new_tokens: Upper bound on generated tokens.
        top_p: Nucleus-sampling probability mass (sampling only).
        top_k: Top-k cutoff (sampling only).

    Yields:
        The cumulative reply text generated so far.

    Raises:
        Any exception raised by ``model.generate`` in the worker thread.
    """
    print(f'message: {message}')
    print(f'history: {history}')

    full_prompt = _build_prompt(message, history, system_prompt)
    inputs = tokenizer(full_prompt, truncation=False, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    stop_event = Event()
    generate_kwargs = dict(
        inputs=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=int(max_new_tokens),
        num_beams=1,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([_StopOnEvent(stop_event)]),
    )
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=top_p,
            top_k=int(top_k),
        )
    else:
        # The slider allows 0, but transformers rejects temperature=0 when
        # sampling; decode greedily instead.
        generate_kwargs.update(do_sample=False)

    errors = []

    def _generate():
        """Run generation in the worker thread, recording any exception."""
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # re-raised in the caller's thread below
            errors.append(exc)
            # Unblock the consumer loop now rather than after the streamer timeout.
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    buffer = ""
    try:
        for new_text in streamer:
            buffer += new_text
            yield buffer
    finally:
        # Runs on normal completion and when Gradio closes this generator
        # (e.g. the user presses Stop): halt generation and reap the thread.
        stop_event.set()
        thread.join()

    if errors:
        raise errors[0]


chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)

with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML(TITLE)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Textbox(
                value="You are a helpful assistant capable of generating long-form content.",
                label="System Prompt",
                render=False,
            ),
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.5,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=1024,
                maximum=32768,
                step=1024,
                value=32768,
                label="Max new tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="Top p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=100,
                step=1,
                value=50,
                label="Top k",
                render=False,
            ),
        ],
        examples=[
            ["Write a 5000-word comprehensive guide on machine learning for beginners."],
            ["Create a detailed 3000-word business plan for a sustainable energy startup."],
            ["Compose a 2000-word short story set in a futuristic underwater city."],
            ["Develop a 4000-word research proposal on the potential effects of climate change on global food security."],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()