import gc

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

from models import models
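# `models` (local module) is assumed to map each model name to its per-model
# settings: a prompt `template`, an `is_messages` flag for chat-style models,
# and default generation kwargs that are forwarded to model.generate().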
# Module-level cache so repeated calls reuse the already-loaded model/tokenizer.
tokenizer = None
model = None
loaded_model_name = None
loaded_dtype = None

def load_model(model_name, dtype='int4'):
    global tokenizer, model, loaded_model_name, loaded_dtype
    if loaded_model_name == model_name and loaded_dtype == dtype:
        return
    # Release the previously loaded model and reclaim GPU memory first.
    del model
    del tokenizer
    model = None
    tokenizer = None
    gc.collect()
    torch.cuda.empty_cache()
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if dtype == 'int4':
        # 4-bit quantization via bitsandbytes.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            trust_remote_code=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            ),
        )
    elif dtype == 'int8':
        # 8-bit quantization. torch_dtype is not a BitsAndBytesConfig field,
        # so the dtype of the non-quantized modules is set on from_pretrained.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            quantization_config=BitsAndBytesConfig(
                load_in_8bit=True,
            ),
        )
    elif dtype == 'fp16':
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
    elif dtype == 'bf16':
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            device_map="auto",
        )
    loaded_model_name = model_name
    loaded_dtype = dtype
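
# Usage sketch for load_model (the model id is a placeholder, not a repo this
# project necessarily uses; any Hub causal-LM id works the same way):
#   load_model("some-org/some-7b-instruct", dtype='int4')
#   load_model("some-org/some-7b-instruct", dtype='int4')  # cached: returns immediately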

def infer(args: dict):
    global tokenizer, model, loaded_model_name
    # Accept 'model' as an alias for 'model_name'.
    if 'model' in args:
        args['model_name'] = args['model']
    # (Re)load when nothing is loaded yet or a different model is requested.
    if not tokenizer or ('model_name' in args and loaded_model_name != args['model_name']):
        if 'dtype' in args:
            load_model(args['model_name'], args['dtype'])
        else:
            load_model(args['model_name'])
    # Start from the per-model defaults, then let the caller's args override
    # them; copy() so the shared entry in `models` is not mutated.
    config = {}
    if args['model_name'] in models:
        config = models[args['model_name']].copy()
    config.update(args)
    if config.get('is_messages'):
        # Chat-style model: build a message list and render it with the
        # tokenizer's chat template.
        messages = [{"role": "system", "content": args['instruction']}]
        if args['input']:
            messages.append({"role": "user", "content": args['input']})
        tprompt = tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
    else:
        # Plain-prompt model: fill in the model-specific template.
        tprompt = config['template'].format(bos_token=tokenizer.bos_token, instruction=args['instruction'], input=args['input'])
    # Whatever remains in config after stripping routing/prompt keys is passed
    # through to model.generate() as generation kwargs.
    kwargs = config.copy()
    for k in ['model_name', 'template', 'instruction', 'input', 'location', 'endpoint', 'model', 'dtype', 'is_messages']:
        kwargs.pop(k, None)
    with torch.no_grad():
        token_ids = tokenizer.encode(tprompt, add_special_tokens=False, return_tensors="pt")
        if config.get('is_messages'):
            output_ids = model.generate(
                input_ids=token_ids.to(model.device),
                do_sample=True,
                **kwargs,
            )
        else:
            output_ids = model.generate(
                input_ids=token_ids.to(model.device),
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                **kwargs,
            )
    # Decode only the newly generated tokens, dropping the prompt prefix.
    out = output_ids.tolist()[0][token_ids.size(1):]
    content = tokenizer.decode(out, skip_special_tokens=True)
    return content
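
# Minimal usage sketch with placeholder values: the model id is hypothetical,
# and 'max_new_tokens'/'temperature' are ordinary generate() kwargs that pass
# through untouched, while the remaining keys are consumed by infer() itself.
if __name__ == "__main__":
    print(infer({
        'model_name': "some-org/some-7b-instruct",  # hypothetical Hub repo id
        'dtype': 'int4',
        'is_messages': True,
        'instruction': "You are a helpful assistant.",
        'input': "Explain 4-bit quantization in one paragraph.",
        'max_new_tokens': 256,
        'temperature': 0.7,
    }))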