import torch
import gradio as gr
from threading import Thread
from peft import PeftModel, PeftConfig
from unsloth import FastLanguageModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
    TextStreamer,
)

# Load the base model and apply the LoRA adapter
config = PeftConfig.from_pretrained("bilgee/Llama-3.1-8B-MN_Instruct")
model = AutoModelForCausalLM.from_pretrained("unsloth/llama-3-8b", torch_dtype=torch.float16)
model = PeftModel.from_pretrained(model, "bilgee/Llama-3.1-8B-MN_Instruct")

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("bilgee/Llama-3.1-8B-MN_Instruct")

alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

# Enable Unsloth's native 2x faster inference
FastLanguageModel.for_inference(model)

# Create a text streamer (not used by the chat handler below, which builds its
# own TextIteratorStreamer per request)
text_streamer = TextStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

# Pick the device based on GPU availability and move the model onto it
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as one of the stop token IDs is produced."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [29, 0]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


# The current implementation does not support multi-turn conversation;
# the previous chat history is ignored. Experimenting with different
# generation hyperparameters is highly recommended to compare output quality.
def predict(message, history):
    stop = StopOnTokens()
    messages = alpaca_prompt.format(
        message,
        "",  # no additional input
        "",  # response left empty for the model to complete
    )
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        top_p=0.95,
        temperature=0.001,
        repetition_penalty=1.1,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    # Run generation in a background thread and stream tokens back to the UI
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    for new_token in streamer:
        if new_token != '<':
            partial_message += new_token
            yield partial_message


gr.ChatInterface(predict).launch(debug=True, share=True, show_api=True)
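
# --- Optional usage sketch (not part of the original app) ---
# Because launch(share=True, show_api=True) exposes a public API, the demo can
# also be queried programmatically. A minimal sketch, assuming the
# gradio_client package and the "/chat" api_name that gr.ChatInterface
# registers by default; the share URL below is a hypothetical placeholder
# printed at launch time.
#
# from gradio_client import Client
#
# client = Client("https://xxxxxxxx.gradio.live")  # hypothetical share URL
# reply = client.predict("Write a short poem about the steppe.", api_name="/chat")
# print(reply)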