import os from threading import Thread import gradio as gr import torch from transformers import (AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer) theme = gr.themes.Monochrome( primary_hue="indigo", secondary_hue="blue", neutral_hue="slate", radius_size=gr.themes.sizes.radius_sm, font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"], ) HF_TOKEN = os.environ.get("HF_TOKEN", None) os.environ["TOKENIZERS_PARALLELISM"] = "false" device = "cuda" if torch.cuda.is_available() else "cpu" model_id = "trl-lib/llama-se-rl-merged" if device == "cpu": model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, use_auth_token=HF_TOKEN) else: model = AutoModelForCausalLM.from_pretrained( model_id, device_map="auto", load_in_8bit=True, use_auth_token=HF_TOKEN ) tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=HF_TOKEN) PROMPT_TEMPLATE = """Question: {prompt}\n\nAnswer: """ def generate(instruction, temperature=1, max_new_tokens=256, top_p=1, top_k=0): formatted_instruction = PROMPT_TEMPLATE.format(prompt=instruction) streamer = TextIteratorStreamer(tokenizer) model_inputs = tokenizer(formatted_instruction, return_tensors="pt", truncation=True, max_length=2048).to(device) generate_kwargs = dict( top_p=top_p, temperature=temperature, max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id, ) t = Thread(target=model.generate, kwargs={**dict(model_inputs, streamer=streamer), **generate_kwargs}) t.start() output = "" hidden_output = "" for new_text in streamer: # skip streaming until new text is available if len(hidden_output) <= len(formatted_instruction): hidden_output += new_text continue # replace eos token if tokenizer.eos_token in new_text: new_text = new_text.replace(tokenizer.eos_token, "") output += new_text yield output return output examples = [ "How do I create an array in C++ of length 5 which contains all even numbers between 1 and 10?", "How can I write a Java function to generate the nth Fibonacci number?", "How can I write a Python function that checks if a given number is a palindrome or not?", "I have a lion in my garden. How can I get rid of it?", ] def process_example(args): for x in generate(args): pass return x with gr.Blocks(theme=theme) as demo: with gr.Column(): gr.Markdown( """