KOMUChat / koalpaca.py
elplaguister
Update koalpaca.py
926b662
raw
history blame
1.77 kB
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig
from peft import PeftModel, PeftConfig
from model import Model
from accelerate import Accelerator
class KoAlpaca(Model):
    """KoAlpaca (Polyglot-Ko base) chat wrapper with the KOMUChat PEFT adapter applied.

    The base model named by the adapter's ``PeftConfig`` is loaded with 4-bit
    NF4 quantization (double quantization, bfloat16 compute) through
    bitsandbytes, the PEFT adapter is loaded on top of it, and generation
    settings are read from a ``GenerationConfig`` file on local disk.
    """

    def __init__(
        self,
        peft_model_id="4n3mone/Komuchat-koalpaca-polyglot-12.8B",
        gen_config_dir='./models/koalpaca',
        gen_config_file='gen_config.json',
    ):
        """Load the quantized base model, the PEFT adapter, tokenizer and generation config.

        Args:
            peft_model_id: Hub id (or local path) of the PEFT adapter. Its
                ``PeftConfig`` names the base model to load.
            gen_config_dir: Directory containing the generation config file.
            gen_config_file: File name of the generation config inside
                ``gen_config_dir``.
        """
        config = PeftConfig.from_pretrained(peft_model_id)
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        # device_map='auto' lets transformers/accelerate place the model across
        # the available devices instead of pinning it to a single GPU.
        base_model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            quantization_config=self.bnb_config,
            device_map='auto',
        )
        self.model = PeftModel.from_pretrained(base_model, peft_model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        self.gen_config = GenerationConfig.from_pretrained(gen_config_dir, gen_config_file)
        # Prompt template; '<INPUT>' is substituted with the user's question.
        self.INPUT_FORMAT = "### 질문: <INPUT>\n\n### 답변:"
        self.model.eval()

    def generate(self, inputs):
        """Generate an answer to a single question.

        Args:
            inputs: The user's question as a plain string.

        Returns:
            The model's answer text, with the prompt, special tokens, any
            hallucinated follow-up "### 질문:" turn, and surrounding
            whitespace removed.
        """
        prompt = self.INPUT_FORMAT.replace('<INPUT>', inputs)
        encoded = self.tokenizer(
            prompt,
            return_tensors='pt',
            return_token_type_ids=False,
        ).to(self.model.device)
        output_ids = self.model.generate(**encoded, generation_config=self.gen_config)

        # A causal LM's output sequence starts with the prompt tokens; decode only
        # what was generated after them. This avoids string-splitting on the
        # answer marker, which fails when the model omits the leading space.
        prompt_len = encoded['input_ids'].shape[-1]
        answer = self.tokenizer.decode(
            output_ids[0, prompt_len:],
            skip_special_tokens=True,  # drop EOS/pad tokens from the returned text
        )
        # If the model did not stop and started a new Q/A turn, keep only the
        # first answer.
        answer, _, _ = answer.partition('### 질문:')
        return answer.strip()