import os from threading import Thread from typing import Iterator import gradio as gr import torch from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer MAX_MAX_NEW_TOKENS = 2048 DEFAULT_MAX_NEW_TOKENS = 1024 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096")) DESCRIPTION = """\ # Llama-3.1 8B Chat Education """ if not torch.cuda.is_available(): DESCRIPTION += "\n

Running on CPU 🥶 This demo does not work on GPU.

" model_id = "minhquy1624/SFT-LLAMA3-8B-Education" model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype= torch.bfloat16).to("cpu") tokenizer = AutoTokenizer.from_pretrained(model_id) def generate( message: str, history: list = None, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.95, top_k: int = 50, repetition_penalty: float = 1.05, ) -> Iterator[str]: conversation = [ { "content": "Bạn là một chuyên gia trong lĩnh vực giáo dục và không được đưa ra thông tin sai lệch, sai sự thật. Nếu câu hỏi đó bạn không biết hãy trả lời 'Tôi không biết'. Mỗi câu trả lời đưa ra phải dựa trên văn bản tham chiếu.", "role": "system" }, { "content": message, "role": "user" }, ] input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt") if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH: input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.") input_ids = input_ids.to(model.device) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( {"input_ids": input_ids}, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p, top_k=top_k, temperature=temperature, repetition_penalty=repetition_penalty, eos_token_id= tokenizer.eos_token_id ) t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() outputs = [] for text in streamer: outputs.append(text) yield "".join(outputs) chat_interface = gr.ChatInterface( fn=generate, additional_inputs=[ gr.Slider( label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS, ), gr.Slider( label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6, ), gr.Slider( label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9, ), gr.Slider( label="Top-k", minimum=1, maximum=1000, step=1, value=50, ), gr.Slider( label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2, ), ], stop_btn=None, examples=[ ["Tôi muốn biết người nước ngoài muốn học thạc sĩ tại Đại học Quốc gia Hà Nội phải tuân theo quy định nào?"], ["Tôi muốn hỏi hệ thống hỗ trợ tuyển sinh chung của Bộ GD&ĐT có vai trò gì trong quy trình tuyển sinh tại ĐHQGHN?"], ["Tôi muốn hỏi xử lý nguyện vọng nghĩa là gì trong quy trình tuyển sinh của ĐHQGHN?"], ["Em có thể sử dụng kết quả thi tốt nghiệp THPT để xét tuyển vào ĐHQGHN không?"], ["Tôi muốn hỏi ĐHQGHN có thể điều chỉnh phương thức xét tuyển trong những năm tiếp theo không?"], ], cache_examples=False, type="messages", ) with gr.Blocks(fill_height=True) as demo: gr.Markdown(DESCRIPTION) gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button") chat_interface.render() if __name__ == "__main__": demo.queue(max_size=20).launch()