import gc

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

from models import models

tokenizer = None
model = None
loaded_model_name = None
loaded_dtype = None


def load_model(model_name, dtype='int4'):
    """Load (or reload) the tokenizer and model, quantized according to `dtype`."""
    global tokenizer, model, loaded_model_name, loaded_dtype

    # Skip reloading if the requested model/dtype is already resident.
    if loaded_model_name == model_name and loaded_dtype == dtype:
        return

    # Release any previously loaded model and free GPU memory before loading.
    del model
    del tokenizer
    model = None
    tokenizer = None
    gc.collect()
    torch.cuda.empty_cache()

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    if dtype == 'int4':
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            trust_remote_code=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            ),
        )
    elif dtype == 'int8':
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            trust_remote_code=True,
            quantization_config=BitsAndBytesConfig(
                load_in_8bit=True,
            ),
        )
    elif dtype == 'fp16':
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
    elif dtype == 'bf16':
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            device_map="auto",
        )

    loaded_model_name = model_name
    loaded_dtype = dtype


def infer(args: dict):
    """Run one generation. `args` carries the model name, prompt parts, and any
    generation overrides; per-model defaults come from the `models` registry."""
    global tokenizer, model, loaded_model_name

    # Accept 'model' as an alias for 'model_name'.
    if 'model' in args:
        args['model_name'] = args['model']

    # (Re)load when nothing is loaded yet or a different model is requested.
    if not tokenizer or ('model_name' in args and loaded_model_name != args['model_name']):
        if 'dtype' in args:
            load_model(args['model_name'], args['dtype'])
        else:
            load_model(args['model_name'])

    # Start from the per-model defaults (copied so the shared registry entry is
    # not mutated across calls), then let the caller's arguments override them.
    config = {}
    if args['model_name'] in models:
        config = dict(models[args['model_name']])
    config.update(args)

    if config.get('is_messages'):
        # Chat models: build a messages list and apply the model's chat template.
        messages = [{"role": "system", "content": args['instruction']}]
        if args['input']:
            messages.append({"role": "user", "content": args['input']})
        tprompt = tokenizer.apply_chat_template(
            conversation=messages, add_generation_prompt=True, tokenize=False
        )
    else:
        # Plain completion models: fill the prompt template from the config.
        tprompt = config['template'].format(
            bos_token=tokenizer.bos_token,
            instruction=args['instruction'],
            input=args['input'],
        )

    # Whatever remains in the config is passed straight to model.generate().
    kwargs = config.copy()
    for k in ['model_name', 'template', 'instruction', 'input', 'location',
              'endpoint', 'model', 'dtype', 'is_messages']:
        kwargs.pop(k, None)

    with torch.no_grad():
        token_ids = tokenizer.encode(tprompt, add_special_tokens=False, return_tensors="pt")
        if config.get('is_messages'):
            output_ids = model.generate(
                input_ids=token_ids.to(model.device),
                do_sample=True,
                **kwargs,
            )
        else:
            output_ids = model.generate(
                input_ids=token_ids.to(model.device),
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                **kwargs,
            )

    # Drop the prompt tokens and decode only the newly generated part.
    out = output_ids.tolist()[0][token_ids.size(1):]
    content = tokenizer.decode(out, skip_special_tokens=True)
    return content
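

# --- Usage sketch -----------------------------------------------------------
# A minimal, hypothetical example of calling infer(); none of the values below
# are required by this module. "your-org/your-model" is a placeholder model id,
# and 'is_messages': True is passed directly so the chat-template branch is
# taken even when the model has no entry in the `models` registry. The
# remaining keys are ordinary generate() sampling parameters.
if __name__ == "__main__":
    example_args = {
        "model": "your-org/your-model",  # placeholder: any causal LM on the Hub
        "dtype": "int4",                 # one of int4, int8, fp16, bf16; anything else loads defaults
        "is_messages": True,             # route the prompt through apply_chat_template()
        "instruction": "You are a helpful assistant.",
        "input": "Explain what 4-bit quantization does in one sentence.",
        "max_new_tokens": 128,
        "temperature": 0.7,
    }
    print(infer(example_args))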