from typing import Any, Dict, List

import torch
import transformers
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory
from torch import cuda
from transformers import BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList

# Checkpoint the pipeline always loads, and the tokenizer that goes with it.
_MODEL_ID = "oleksandrfluxon/mpt-7b-instruct-evaluate"
_TOKENIZER_ID = "mosaicml/mpt-7b"

# Token sequences that mark the start of the next chat turn; generation stops
# as soon as the output ends with one of them.
_STOP_SEQUENCES = [["Human", ":"], ["AI", ":"]]


class _StopOnTokens(StoppingCriteria):
    """Stop generation once the output ends with one of the stop sequences."""

    def __init__(self, stop_token_ids: List[torch.Tensor]) -> None:
        super().__init__()
        self._stop_token_ids = stop_token_ids

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs: Any,
    ) -> bool:
        # Only the first sequence of the batch is inspected.
        sequence = input_ids[0]
        for stop_ids in self._stop_token_ids:
            width = len(stop_ids)
            # A sequence shorter than the stop sequence cannot end with it;
            # skipping it also avoids a broadcast / shape error in torch.eq.
            if sequence.shape[0] < width:
                continue
            if torch.eq(sequence[-width:], stop_ids).all():
                return True
        return False


class PreTrainedPipeline:
    """Text-generation pipeline around an 8-bit MPT-7B instruct checkpoint."""

    def __init__(self, path: str = "") -> None:
        """Load the model and tokenizer and build the generation pipeline.

        Args:
            path: Accepted for interface compatibility but ignored; the
                checkpoint named by ``_MODEL_ID`` is always loaded.
        """
        # NOTE(review): the caller-supplied path is overridden on purpose in
        # the original code; confirm before honouring it.
        path = _MODEL_ID
        print("===> path", path)

        device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
        print("===> device", device)

        model = transformers.AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            load_in_8bit=True,  # this requires the `bitsandbytes` library
            # Extra keyword arguments below are forwarded to the model's
            # configuration (presumably MPT-specific options; confirm).
            max_seq_len=8192,
            init_device=device,
        )
        model.eval()
        print(f"===> Model loaded on {device}")

        tokenizer = transformers.AutoTokenizer.from_pretrained(_TOKENIZER_ID)

        # Convert each stop sequence to token ids on the same device as the
        # generated ids so they can be compared directly.
        stop_token_ids = [
            torch.LongTensor(tokenizer.convert_tokens_to_ids(tokens)).to(device)
            for tokens in _STOP_SEQUENCES
        ]
        print("===> stop_token_ids", stop_token_ids)

        stopping_criteria = StoppingCriteriaList([_StopOnTokens(stop_token_ids)])

        self.pipeline = transformers.pipeline(
            model=model,
            tokenizer=tokenizer,
            return_full_text=True,  # langchain expects the full text
            task='text-generation',
            # we pass model parameters here too
            stopping_criteria=stopping_criteria,  # without this model rambles during chat
            temperature=0.1,  # 'randomness' of outputs, 0.0 is the min and 1.0 the max
            top_p=0.15,  # select from top tokens whose probability add up to 15%
            top_k=0,  # select from top 0 tokens (because zero, relies on top_p)
            max_new_tokens=128,  # max number of tokens to generate in the output
            repetition_penalty=1.1,  # without this output begins repeating
        )
        print("===> init finished")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run text generation for one request payload.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt passed to
                the pipeline (the whole payload is used when the key is
                absent). ``data["parameters"]``, when present, is a mapping of
                keyword arguments forwarded to the pipeline call. A ``"date"``
                key is accepted and discarded. Note that these keys are
                removed from ``data`` in place.

        Returns:
            The value returned by the underlying text-generation pipeline.
        """
        inputs = data.pop("inputs", data)
        # A payload with `"parameters": null` must behave like a missing key;
        # unpacking None with ** would raise TypeError.
        parameters = data.pop("parameters", None) or {}
        # "date" is not used, but it is still removed so it does not remain in
        # the payload when `data` itself serves as the input.
        data.pop("date", None)

        print("===> inputs", inputs)
        print("===> parameters", parameters)

        result = self.pipeline(inputs, **parameters)
        print("===> result", result)

        return result