from abc import ABC

from langchain.llms.base import LLM
import random
import torch
import transformers
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from typing import Optional, List, Dict, Any
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
                         AnswerResult)


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # If the scores contain NaN or Inf, replace the whole distribution with a
        # one-hot spike on token id 5 so generation can continue instead of failing.
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


class LLamaLLM(BaseAnswer, LLM, ABC):
    checkPoint: LoaderCheckPoint = None

    history_len: int = 3
    max_new_tokens: int = 500
    num_beams: int = 1
    temperature: float = 0.5
    top_p: float = 0.4
    top_k: int = 10
    repetition_penalty: float = 1.2
    encoder_repetition_penalty: int = 1
    min_length: int = 0
    logits_processor: LogitsProcessorList = None
    stopping_criteria: Optional[StoppingCriteriaList] = None
    eos_token_id: Optional[List[int]] = [2]

    # Generation/loader state; within this module only 'add_bos_token' is read (in _call).
    state: object = {'max_new_tokens': 50,
                     'seed': 1,
                     'temperature': 0, 'top_p': 0.1,
                     'top_k': 40, 'typical_p': 1,
                     'repetition_penalty': 1.2,
                     'encoder_repetition_penalty': 1,
                     'no_repeat_ngram_size': 0,
                     'min_length': 0,
                     'penalty_alpha': 0,
                     'num_beams': 1,
                     'length_penalty': 1,
                     'early_stopping': False, 'add_bos_token': True, 'ban_eos_token': False,
                     'truncation_length': 2048, 'custom_stopping_strings': '',
                     'cpu_memory': 0, 'auto_devices': False, 'disk': False, 'cpu': False, 'bf16': False,
                     'load_in_8bit': False, 'wbits': 'None', 'groupsize': 'None', 'model_type': 'None',
                     'pre_layer': 0, 'gpu_memory_0': 0}

    def __init__(self, checkPoint: LoaderCheckPoint = None):
        super().__init__()
        self.checkPoint = checkPoint

    @property
    def _llm_type(self) -> str:
        return "LLamaLLM"

    @property
    def _check_point(self) -> LoaderCheckPoint:
        return self.checkPoint

    def encode(self, prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
        input_ids = self.checkPoint.tokenizer.encode(str(prompt), return_tensors='pt',
                                                     add_special_tokens=add_special_tokens)

        # Drop the BOS token if the caller does not want one.
        if not add_bos_token and input_ids[0][0] == self.checkPoint.tokenizer.bos_token_id:
            input_ids = input_ids[:, 1:]

        # LlamaTokenizer may prepend the SentencePiece whitespace token (id 29871); drop it.
        if type(self.checkPoint.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
            input_ids = input_ids[:, 1:]

        # Keep only the last `truncation_length` tokens of the prompt.
        if truncation_length is not None:
            input_ids = input_ids[:, -truncation_length:]

        return input_ids.cuda()

    def decode(self, output_ids):
        reply = self.checkPoint.tokenizer.decode(output_ids, skip_special_tokens=True)
        return reply

    def history_to_text(self, query, history):
        """
        Build the soft prompt from the dialogue history.
        Converts the `history` list into the "### Human:/### Assistant:" text format
        and appends the current query. The caller then turns the result into token
        ids with self.encode, so the past turns and the current input end up in a
        single prompt for generation.
        :return: the formatted prompt string
        """
        formatted_history = ''
        history = history[-self.history_len:] if self.history_len > 0 else []
        if len(history) > 0:
            for i, (old_query, response) in enumerate(history):
                formatted_history += "### Human:{}\n### Assistant:{}\n".format(old_query, response)
        formatted_history += "### Human:{}\n### Assistant:".format(query)
        return formatted_history

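    # Illustrative example of the soft prompt produced by history_to_text, derived
    # directly from the format strings above (the dialogue content is made up):
    #
    #   history = [["What is LangChain?", "A framework for building LLM apps."]]
    #   query   = "Does it work with local models?"
    #
    # yields:
    #
    #   ### Human:What is LangChain?
    #   ### Assistant:A framework for building LLM apps.
    #   ### Human:Does it work with local models?
    #   ### Assistant:
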
    def prepare_inputs_for_generation(self,
                                      input_ids: torch.LongTensor):
        """
        Pre-build the attention mask and the tensor of position indices for the
        input sequence before generation.
        Note: get_masks and get_position_ids are not part of this class's original
        implementation (see the illustrative sketches after this method), and this
        method is not called by _call.
        # TODO: no clear approach yet
        :return:
        """

        mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device)

        attention_mask = self.get_masks(input_ids, input_ids.device)

        position_ids = self.get_position_ids(
            input_ids,
            device=input_ids.device,
            mask_positions=mask_positions
        )

        return input_ids, position_ids, attention_mask

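    # prepare_inputs_for_generation above expects get_masks / get_position_ids,
    # which this class does not define. The two methods below are a minimal
    # illustrative sketch of what they might return for a plain causal
    # LLaMA-style model (full attention, 0..n-1 position ids); they are an
    # assumption, not part of the original implementation, and they ignore
    # mask_positions.
    def get_masks(self, input_ids: torch.LongTensor, device) -> torch.Tensor:
        # Attend to every position (no padding handling in this sketch).
        return torch.ones_like(input_ids, device=device)

    def get_position_ids(self, input_ids: torch.LongTensor, device, mask_positions=None) -> torch.Tensor:
        # Position ids are simply 0..seq_len-1 for each sequence in the batch.
        batch_size, seq_len = input_ids.shape
        return torch.arange(seq_len, dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, seq_len)
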
    @property
    def _history_len(self) -> int:
        return self.history_len

    def set_history_len(self, history_len: int = 10) -> None:
        self.history_len = history_len

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        print(f"__call:{prompt}")
        if self.logits_processor is None:
            self.logits_processor = LogitsProcessorList()
            self.logits_processor.append(InvalidScoreLogitsProcessor())

        gen_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "num_beams": self.num_beams,
            "top_p": self.top_p,
            "do_sample": True,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            "encoder_repetition_penalty": self.encoder_repetition_penalty,
            "min_length": self.min_length,
            "temperature": self.temperature,
            "eos_token_id": self.checkPoint.tokenizer.eos_token_id,
            "logits_processor": self.logits_processor}

        # Note: the prompt is truncated to the last `max_new_tokens` tokens here,
        # and the `stop` strings are accepted but not applied.
        input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'],
                                truncation_length=self.max_new_tokens)

        gen_kwargs.update({'inputs': input_ids})

        if self.stopping_criteria is None:
            self.stopping_criteria = transformers.StoppingCriteriaList()

        gen_kwargs.update({'stopping_criteria': self.stopping_criteria})

        output_ids = self.checkPoint.model.generate(**gen_kwargs)
        # Keep only the newly generated tokens and decode them into the reply.
        new_tokens = len(output_ids[0]) - len(input_ids[0])
        reply = self.decode(output_ids[0][-new_tokens:])
        print(f"response:{reply}")
        print(f"+++++++++++++++++++++++++++++++++++")
        return reply

    def generatorAnswer(self, prompt: str,
                        history: List[List[str]] = [],
                        streaming: bool = False):
        # streaming is accepted for interface compatibility but not implemented here;
        # the full reply is generated in one call and yielded once.
        softprompt = self.history_to_text(prompt, history=history)
        response = self._call(prompt=softprompt, stop=['\n###'])

        answer_result = AnswerResult()
        answer_result.history = history + [[prompt, response]]
        answer_result.llm_output = {"answer": response}
        yield answer_result
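

# Minimal usage sketch (not part of the original module). It assumes
# LoaderCheckPoint can be constructed without arguments and exposes
# `model_name` / `reload_model()` the way the rest of this project uses it;
# adjust to the actual loader API in models/loader.
if __name__ == "__main__":
    checkpoint = LoaderCheckPoint()        # assumed default construction
    checkpoint.model_name = "llama-7b-hf"  # hypothetical local model name
    checkpoint.reload_model()              # assumed loader entry point

    llm = LLamaLLM(checkPoint=checkpoint)
    llm.set_history_len(3)
    for answer_result in llm.generatorAnswer("Hello, who are you?", history=[]):
        print(answer_result.llm_output["answer"])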