|
import transformers |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training |
|
import torch |
|
|
|
|
|
|
|
MODEL = "Viet-Mistral/Vistral-7B-Chat" |
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
print('device =', device) |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
'Viet-Mistral/Vistral-7B-Chat', |
|
torch_dtype=torch.bfloat16, |
|
device_map="auto" |
|
|
|
|
|
) |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(MODEL, cache_dir='/workspace/thviet/hf_cache') |
|
|
|
lora_config = LoraConfig.from_pretrained( |
|
"thviet79/model-QA-medical" |
|
|
|
) |
|
|
|
|
|
model = get_peft_model(model, lora_config) |
|
|
|
system_prompt = "Bạn là một trợ lí ảo Tiếng Việt về lĩnh vực y tế." |
|
question = "Chào bác sĩ,\nRăng cháu hiện tại có mủ ở dưới lợi nhưng khi đau cháu sẽ không ngủ được (quá đau). Tuy nhiên chỉ vài ngày là hết mà thỉnh thoảng nó lại bị đau. Chị cháu bảo là trước chị cháu cũng bị như vậy chỉ là đau răng tuổi dậy thì thôi. Bác sĩ cho cháu hỏi đau răng kèm có mủ dưới lợi là bệnh gì? Cháu có cần đi chữa trị không? Cháu cảm ơn." |
|
|
|
conversation = [{"role": "system", "content": system_prompt }] |
|
human = f"Vui lòng trả lời câu hỏi sau: {question}" |
|
conversation.append({"role": "user", "content": human }) |
|
|
|
|
|
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(device) |
|
|
|
|
|
out_ids = model.generate( |
|
input_ids=input_ids, |
|
max_new_tokens=768, |
|
do_sample=True, |
|
top_p=0.95, |
|
top_k=40, |
|
temperature=0.1, |
|
repetition_penalty=1.05, |
|
) |
|
|
|
|
|
assistant = tokenizer.batch_decode(out_ids[:, input_ids.size(1):], skip_special_tokens=True)[0].strip() |
|
print("Assistant: ", assistant) |
|
|