import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

hf_token = '..........'

# Placeholder path for the fine-tuned LoRA adapter directory; the original
# script left CHECKPOINT_PATH undefined, so point this at your own checkpoint.
CHECKPOINT_PATH = './vistral-lora-checkpoint'

tokenizer = AutoTokenizer.from_pretrained('./vistral-tokenizer')

# 4-bit NF4 quantization so the 7B base model fits in modest GPU memory.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load the quantized base model, then attach the LoRA adapter on top of it.
model = AutoModelForCausalLM.from_pretrained(
    'Viet-Mistral/Vistral-7B-Chat',
    device_map="auto",
    token=hf_token,
    quantization_config=bnb_config,
)
ft_model = PeftModel.from_pretrained(model, CHECKPOINT_PATH)

# torch.backends.cuda.enable_mem_efficient_sdp(False)
# torch.backends.cuda.enable_flash_sdp(False)

# System prompt (Vietnamese): "You are an enthusiastic and honest Vietnamese
# assistant. Always answer as helpfully as possible while staying safe."
system_prompt = "Bạn là một trợ lí Tiếng Việt nhiệt tình và trung thực. Hãy luôn trả lời một cách hữu ích nhất có thể, đồng thời giữ an toàn."

# Stop generation on either the tokenizer's EOS token or the ChatML
# <|im_end|> turn marker.
stop_tokens = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids('<|im_end|>')]

def chat_test():
    """Simple chat REPL: type 'reset' to clear the history, 'exit' to quit."""
    conversation = [{"role": "system", "content": system_prompt}]
    while True:
        human = input("Human: ")
        if human.lower() == "reset":
            conversation = [{"role": "system", "content": system_prompt}]
            print("The chat history has been cleared!")
            continue
        if human.lower() == "exit":
            break

        conversation.append({"role": "user", "content": human})

        # Render the ChatML-formatted history and append the assistant header
        # so the model continues the conversation as the assistant.
        formatted = tokenizer.apply_chat_template(conversation, tokenize=False) + "<|im_start|>assistant"
        tok = tokenizer(formatted, return_tensors="pt").to(ft_model.device)
        input_ids = tok['input_ids']

        out_ids = ft_model.generate(
            input_ids=input_ids,
            attention_mask=tok['attention_mask'],
            eos_token_id=stop_tokens,
            max_new_tokens=50,
            do_sample=True,
            top_p=0.95,
            top_k=40,
            temperature=0.1,
            repetition_penalty=1.05,
        )

        # Decode only the newly generated tokens, not the echoed prompt.
        assistant = tokenizer.batch_decode(out_ids[:, input_ids.size(1):], skip_special_tokens=True)[0].strip()
        print("Assistant: ", assistant)
        conversation.append({"role": "assistant", "content": assistant})

chat_test()
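
# --- Optional: streaming output ---------------------------------------------
# A minimal sketch, not part of the original script: the original import list
# pulled in transformers' TextStreamer without using it, and it can print the
# reply token by token instead of waiting for generate() to finish.
# generate_streaming() below is a hypothetical helper that mirrors the
# sampling parameters of chat_test(); to use it, move this definition above
# chat_test() and swap it in for the plain generate() call there.
from transformers import TextStreamer

def generate_streaming(tok):
    # Streams each sampled token to stdout as it is produced, then returns
    # the full output ids exactly like the plain generate() call above.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    return ft_model.generate(
        input_ids=tok['input_ids'],
        attention_mask=tok['attention_mask'],
        eos_token_id=stop_tokens,
        max_new_tokens=50,
        do_sample=True,
        top_p=0.95,
        top_k=40,
        temperature=0.1,
        repetition_penalty=1.05,
        streamer=streamer,  # print tokens to stdout as they arrive
    )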