import gradio as gr
import torch
import llm_blender
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)
from accelerate import infer_auto_device_map
from typing import List

from model_utils import build_tokenizer, build_model, get_llm_prompt, get_stop_str_and_ids

BASE_LLM_NAMES = [
    "chavinlo/alpaca-native",
    "eachadea/vicuna-13b-1.1",
    "databricks/dolly-v2-12b",
    "stabilityai/stablelm-tuned-alpha-7b",
    "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
    "TheBloke/koala-13B-HF",
    "project-baize/baize-v2-13b",
    "google/flan-t5-xxl",
    "THUDM/chatglm-6b",
    "fnlp/moss-moon-003-sft",
    "mosaicml/mpt-7b-chat",
]
# Lazily populated caches: models and tokenizers are loaded on first use.
BASE_LLM_MODELS = {name: None for name in BASE_LLM_NAMES}
BASE_LLM_TOKENIZERS = {name: None for name in BASE_LLM_NAMES}


class StopTokenIdsCriteria(StoppingCriteria):
    """
    Stop generation as soon as the last generated token of every sequence in
    the batch is one of `stop_token_ids`. Unlike `MaxLengthCriteria`, this
    does not depend on the sequence length at all.

    Args:
        stop_token_ids (`List[int]`):
            Token ids that should terminate generation.
    """

    def __init__(self, stop_token_ids: List[int]):
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.stop_token_ids:
            return all(_input_ids[-1] in self.stop_token_ids for _input_ids in input_ids)
        return False


def llm_generate(
    base_llm_name: str,
    instruction: str,
    input: str,
    max_new_tokens: int,
    top_p=1.0,
    temperature=0.7,
) -> str:
    # Load and cache the model/tokenizer the first time this LLM is requested.
    if BASE_LLM_MODELS.get(base_llm_name, None) is None:
        BASE_LLM_MODELS[base_llm_name] = build_model(
            base_llm_name, device_map="auto", load_in_8bit=True, trust_remote_code=True)
    if BASE_LLM_TOKENIZERS.get(base_llm_name, None) is None:
        BASE_LLM_TOKENIZERS[base_llm_name] = build_tokenizer(
            base_llm_name, trust_remote_code=True)
    base_llm = BASE_LLM_MODELS[base_llm_name]
    base_llm_tokenizer = BASE_LLM_TOKENIZERS[base_llm_name]

    llm_prompt = get_llm_prompt(base_llm_name, instruction, input)
    stop_str, stop_token_ids = get_stop_str_and_ids(base_llm_tokenizer)

    # Budget 256 tokens for the instruction + input on top of the prompt
    # template itself (the prompt with instruction and input stripped out).
    template_length = len(base_llm_tokenizer.encode(
        llm_prompt.replace(instruction, "").replace(input, "")))
    encoded_llm_prompt = base_llm_tokenizer(
        llm_prompt,
        max_length=256 + template_length,
        padding='max_length',
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded_llm_prompt["input_ids"].to(base_llm.device)
    attention_mask = encoded_llm_prompt["attention_mask"].to(base_llm.device)

    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "num_return_sequences": 1,
    }
    if stop_token_ids:
        generate_kwargs["stopping_criteria"] = StoppingCriteriaList([
            StopTokenIdsCriteria(stop_token_ids),
        ])
    output_ids = base_llm.generate(**generate_kwargs)

    # Drop the prompt tokens, decode, and truncate at the stop string if any.
    output_ids_wo_prompt = output_ids[0, input_ids.shape[1]:]
    decoded_output = base_llm_tokenizer.decode(output_ids_wo_prompt, skip_special_tokens=True)
    if stop_str:
        pos = decoded_output.find(stop_str)
        if pos != -1:
            decoded_output = decoded_output[:pos]
    return decoded_output


def llms_generate(
    base_llm_names,
    instruction,
    input,
    max_new_tokens,
    top_p=1.0,
    temperature=0.7,
):
    # Generate one candidate output from each selected base LLM.
    return {
        base_llm_name: llm_generate(
            base_llm_name, instruction, input, max_new_tokens, top_p, temperature)
        for base_llm_name in base_llm_names
    }
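

# A minimal usage sketch, not part of the original app: it assumes the
# model_utils helpers above are importable and that enough GPU memory is
# available to load the chosen checkpoints in 8-bit. The model names and
# generation settings below are illustrative only.
if __name__ == "__main__":
    candidates = llms_generate(
        base_llm_names=["chavinlo/alpaca-native", "eachadea/vicuna-13b-1.1"],
        instruction="Summarize the following text in one sentence.",
        input="LLM-Blender ensembles multiple open-source LLMs by ranking and fusing their outputs.",
        max_new_tokens=128,
    )
    for name, output in candidates.items():
        print(f"=== {name} ===\n{output}\n")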