File size: 3,901 Bytes
9123479
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gradio as gr
import torch
import llm_blender
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    StoppingCriteria, StoppingCriteriaList,
)
from accelerate import infer_auto_device_map
from typing import List

from model_utils import build_tokenizer, build_model, get_llm_prompt, get_stop_str_and_ids
# Identifiers of the supported base LLMs; each one is handed to
# `build_model` / `build_tokenizer` when it is first requested.
BASE_LLM_NAMES = [
    "chavinlo/alpaca-native",
    "eachadea/vicuna-13b-1.1",
    "databricks/dolly-v2-12b",
    "stabilityai/stablelm-tuned-alpha-7b",
    "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
    "TheBloke/koala-13B-HF",
    "project-baize/baize-v2-13b",
    "google/flan-t5-xxl",
    "THUDM/chatglm-6b",
    "fnlp/moss-moon-003-sft",
    "mosaicml/mpt-7b-chat",
]

# Per-name caches. Every entry starts as None and is replaced by the built
# object the first time that name is used. The two dicts are independent.
BASE_LLM_MODELS = dict.fromkeys(BASE_LLM_NAMES)
BASE_LLM_TOKENIZERS = dict.fromkeys(BASE_LLM_NAMES)

class StopTokenIdsCriteria(StoppingCriteria):
    """
    Stop generation once every sequence in the batch ends with a stop token.

    After each generation step the last token id of every row of `input_ids` is
    compared against `stop_token_ids`. Generation is halted only when **all**
    rows currently end with one of those ids.

    Args:
        stop_token_ids (`List[int]`):
            Token ids that mark the end of a response. If the list is empty
            (or otherwise falsy) the criterion never requests a stop.
    """

    def __init__(self, stop_token_ids: List[int]):
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # With no stop ids configured there is nothing to match against.
        if not self.stop_token_ids:
            return False
        # Convert each row's last token to a plain int once, so membership is
        # an int comparison rather than a tensor comparison per stop id.
        return all(int(sequence[-1]) in self.stop_token_ids for sequence in input_ids)
    
def llm_generate(
    base_llm_name: str, instruction: str, input: str,
    max_new_tokens: int, top_p=1.0, temperature=0.7,
) -> str:
    """Sample one response from a base LLM for an instruction/input pair.

    The model and tokenizer for `base_llm_name` are built on first use and
    cached in `BASE_LLM_MODELS` / `BASE_LLM_TOKENIZERS`, so later calls with
    the same name reuse them.

    Args:
        base_llm_name: Name passed to `build_model`, `build_tokenizer` and
            `get_llm_prompt`.
        instruction: Instruction text inserted into the model's prompt.
        input: Additional input text inserted into the model's prompt.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass forwarded to `generate`.
        temperature: Sampling temperature forwarded to `generate`.

    Returns:
        The generated text, decoded with special tokens skipped and cut at the
        first occurrence of the model's stop string (when one is defined).
    """
    # Build lazily: the cache entry is None (or missing) until first use.
    if BASE_LLM_MODELS.get(base_llm_name) is None:
        BASE_LLM_MODELS[base_llm_name] = build_model(
            base_llm_name, device_map="auto",
            load_in_8bit=True, trust_remote_code=True)
    if BASE_LLM_TOKENIZERS.get(base_llm_name) is None:
        BASE_LLM_TOKENIZERS[base_llm_name] = build_tokenizer(
            base_llm_name, trust_remote_code=True)
    base_llm = BASE_LLM_MODELS[base_llm_name]
    base_llm_tokenizer = BASE_LLM_TOKENIZERS[base_llm_name]
    llm_prompt = get_llm_prompt(base_llm_name, instruction, input)
    stop_str, stop_token_ids = get_stop_str_and_ids(base_llm_tokenizer)

    # Token count of the prompt with the user-supplied text removed, i.e. the
    # overhead of the prompt template itself.
    template_length = len(base_llm_tokenizer.encode(
        llm_prompt.replace(instruction, "").replace(input, "")))

    # Budget 256 tokens for instruction + input on top of the template; the
    # prompt is padded or truncated to exactly that length.
    # NOTE(review): with padding='max_length', a right-padding tokenizer would
    # make decoder-only models continue after pad tokens -- confirm that
    # build_tokenizer configures the padding side appropriately.
    encoded_llm_prompt = base_llm_tokenizer(
        llm_prompt,
        max_length=256 + template_length,
        padding='max_length', truncation=True, return_tensors="pt")

    input_ids = encoded_llm_prompt["input_ids"].to(base_llm.device)
    attention_mask = encoded_llm_prompt["attention_mask"].to(base_llm.device)

    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "num_return_sequences": 1,
    }
    if stop_token_ids:
        generate_kwargs['stopping_criteria'] = StoppingCriteriaList([
            StopTokenIdsCriteria(stop_token_ids),
        ])

    output_ids = base_llm.generate(**generate_kwargs)

    # Decoder-only models return prompt + continuation, so the prompt has to be
    # sliced off. Encoder-decoder models (BASE_LLM_NAMES contains a T5 entry)
    # return decoder tokens only; slicing by the prompt length there would
    # throw the answer away.
    # Assumes build_model returns a transformers model exposing
    # `config.is_encoder_decoder`; when absent, the prompt-slicing path is used.
    model_config = getattr(base_llm, "config", None)
    if getattr(model_config, "is_encoder_decoder", False):
        output_ids_wo_prompt = output_ids[0]
    else:
        output_ids_wo_prompt = output_ids[0, input_ids.shape[1]:]

    decoded_output = base_llm_tokenizer.decode(output_ids_wo_prompt, skip_special_tokens=True)
    # Drop the stop string and anything the model produced after it.
    if stop_str:
        pos = decoded_output.find(stop_str)
        if pos != -1:
            decoded_output = decoded_output[:pos]
    return decoded_output

def llms_generate(
    base_llm_names, instruction, input, 
    max_new_tokens, top_p=1.0, temperature=0.7,
):  
    """Run `llm_generate` once per name and collect the outputs.

    Args:
        base_llm_names: Iterable of base LLM names to query, in order.
        instruction: Instruction text forwarded to every model.
        input: Additional input text forwarded to every model.
        max_new_tokens: Generation length limit forwarded to every model.
        top_p: Nucleus-sampling value forwarded to every model.
        temperature: Sampling temperature forwarded to every model.

    Returns:
        A dict mapping each base LLM name to the text it generated.
    """
    outputs = {}
    for llm_name in base_llm_names:
        outputs[llm_name] = llm_generate(
            llm_name, instruction, input, max_new_tokens, top_p, temperature)
    return outputs