# Spaces: Paused
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig | |
from peft import PeftModel, PeftConfig | |
from model import Model | |
class KoAlpaca(Model):
    """KoAlpaca (Polyglot-Ko 12.8B) chat model with a LoRA adapter, loaded in 4-bit.

    The base model named by the adapter's PEFT config is loaded with a
    bitsandbytes 4-bit quantization config (NF4, double quantization,
    bfloat16 compute dtype). The PEFT adapter is applied on top of it.
    Generation settings come from a ``GenerationConfig`` stored in a local
    directory.
    """

    def __init__(
        self,
        peft_model_id="4n3mone/Komuchat-koalpaca-polyglot-12.8B",
        gen_config_dir="./models/koalpaca",
        gen_config_file="gen_config.json",
    ):
        """Load the quantized base model, the adapter, the tokenizer and the generation config.

        Args:
            peft_model_id: Hub id or local path of the PEFT adapter. Its
                config supplies ``base_model_name_or_path`` for the base model
                and tokenizer.
            gen_config_dir: Directory containing the generation config. A
                relative path is resolved against the current working
                directory.
            gen_config_file: File name of the generation config inside
                ``gen_config_dir``.
        """
        config = PeftConfig.from_pretrained(peft_model_id)
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        # device_map="auto" lets transformers place the weights on the
        # available devices itself; generate() therefore moves its inputs to
        # self.model.device instead of a hard-coded device.
        base_model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            quantization_config=self.bnb_config,
            device_map="auto",
        )
        self.model = PeftModel.from_pretrained(base_model, peft_model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        self.gen_config = GenerationConfig.from_pretrained(gen_config_dir, gen_config_file)
        # Prompt template: "### Question: <INPUT>\n\n### Answer:".
        # generate() substitutes the user's text for <INPUT>.
        self.INPUT_FORMAT = "### 질문: <INPUT>\n\n### 답변:"
        self.model.eval()

    def generate(self, inputs: str) -> str:
        """Generate an answer to a single question.

        Args:
            inputs: The user's question text. It is substituted into
                ``INPUT_FORMAT``.

        Returns:
            Only the newly generated text, decoded without special tokens and
            stripped of surrounding whitespace.
        """
        prompt = self.INPUT_FORMAT.replace("<INPUT>", inputs)
        # token_type_ids are suppressed; presumably the base model's forward()
        # does not accept them — NOTE(review): confirm.
        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            return_token_type_ids=False,
        ).to(self.model.device)
        output_ids = self.model.generate(**encoded, generation_config=self.gen_config)
        # For a causal LM, generate() returns the prompt followed by the
        # continuation. Decoding only the tokens after the prompt avoids
        # relying on string-splitting at the answer marker, which breaks when
        # the marker appears in the user's text or the model's output, or
        # when no space follows it.
        prompt_len = encoded["input_ids"].shape[-1]
        answer_ids = output_ids[0][prompt_len:]
        return self.tokenizer.decode(answer_ids, skip_special_tokens=True).strip()