File size: 7,220 Bytes
9f26773
 
 
63b3bac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
---
license: creativeml-openrail-m
---

This model is an ONNX export of [Pygmalion-6b](https://huggingface.co/PygmalionAI/pygmalion-6b); all credit goes to PygmalionAI.

Be warned: this ONNX export is not fully accurate, and its weights are upcast to Float32 because of a limitation of the PyTorch ONNX exporter, so it occupies twice the memory of the original Pygmalion model. The purpose of this export is to obtain
a list of operators and nodes that can later be used to run Pygmalion-6B inference on Vulkan Compute. That would eventually enable hassle-free inference with INT8 or INT4 quantization
on almost any device that supports Vulkan Compute out of the box.

The scripts used for the export are listed below. model.py is taken from [PygmalionAI/gradio-ui](https://github.com/PygmalionAI/gradio-ui) and is licensed under the GNU Affero General Public License v3.0.
In accordance with that license, all scripts listed below are also licensed under the GNU Affero General Public License v3.0.


**export.py**
```py
import torch
import onnx
import transformers
import typing as t

import os

model_name = "PygmalionAI/pygmalion-6b"
from model import build_model_and_tokenizer_for, run_raw_inference
model, tokenizer = build_model_and_tokenizer_for(model_name)
# Upcast all weights to float32 on the CPU before exporting.
model.to('cpu').float()

# Sanity-check the tokenizer on a sample sentence (printed only; this is not
# part of the ONNX export).
input_text = "Hello, how are you today?"
encoded_input = tokenizer.encode(input_text, return_tensors='pt')
print(f"Raw: {input_text}")
print(f"Encoded: {encoded_input}")

# Export the PyTorch model to ONNX format.
output_path = "onnx/pygmalion-6b.onnx"
# Create the destination directory up front so a missing "onnx/" folder does
# not make the export fail only after the (slow) trace of the whole model.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# A (batch=1, seq=10) tensor of token id 0 is only used to trace the graph;
# the dynamic axes below mark batch and sequence length as variable.
dummy_input = torch.zeros((1, 10), dtype=torch.long)
input_names = ["input_ids"]
# NOTE(review): only the first graph output is named here; if the model also
# returns cached key/values, the extra outputs keep auto-generated names —
# confirm against the exported graph.
output_names = ["output"]
dynamic_axes = {"input_ids": {0: "batch_size", 1: "sequence_length"},
                "output": {0: "batch_size", 1: "sequence_length"}}
torch.onnx.export(model, dummy_input, output_path, input_names=input_names,
                  output_names=output_names, dynamic_axes=dynamic_axes,
                  opset_version=12)
```

**model.py**
```py
import logging
import typing as t

import torch
import transformers

logger = logging.getLogger(__name__)


def build_model_and_tokenizer_for(
    model_name: str
) -> t.Tuple[transformers.AutoModelForCausalLM, transformers.AutoTokenizer]:
    '''Loads the tokenizer and causal-LM model for ``model_name``.

    The phrases from ``_build_bad_words_list_for`` are tokenized (without
    special tokens) and passed to ``from_pretrained`` as ``bad_words_ids``,
    presumably so generation avoids emitting them.

    :param model_name: Hugging Face model identifier or local path.
    :return: ``(model, tokenizer)``; the model is put in eval mode on the CPU.
    '''
    logger.info(f"Loading tokenizer for {model_name}")
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

    # NOTE(11b): non-OPT models support passing this in at inference time, might
    # be worth refactoring for a debug version so we're able to experiment on
    # the fly
    banned_phrases = _build_bad_words_list_for(model_name)
    banned_token_ids = [
        tokenizer(phrase, add_special_tokens=False).input_ids
        for phrase in banned_phrases
    ]

    logger.info(f"Loading the {model_name} model")
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name,
        bad_words_ids=banned_token_ids,
    )
    # Inference only: switch off training-mode layers and keep it on the CPU.
    model.eval()
    model.to("cpu")

    logger.info("Model and tokenizer are ready")
    return model, tokenizer

def build_tokenizer_for(
    model_name: str
) -> transformers.AutoTokenizer:
    '''Loads only the tokenizer for ``model_name`` (no model weights).

    :param model_name: Hugging Face model identifier or local path.
    :return: The tokenizer returned by ``AutoTokenizer.from_pretrained``
        (a single object, not a tuple).
    '''
    logger.info(f"Loading tokenizer for {model_name}")
    # Bad-word token ids are only consumed when the model itself is built
    # (see build_model_and_tokenizer_for), so none are computed here.
    return transformers.AutoTokenizer.from_pretrained(model_name)


class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
    '''Stops generation once a sentinel token sequence appears in new tokens.

    Only tokens at or after ``starting_idx`` are searched, so a sentinel that
    already occurs in the prompt does not end generation immediately. If any
    sample in the batch contains the sentinel, generation stops.
    '''

    def __init__(self, sentinel_token_ids: torch.LongTensor,
                 starting_idx: int):
        transformers.StoppingCriteria.__init__(self)
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx

    def __call__(self, input_ids: torch.LongTensor,
                 _scores: torch.FloatTensor, **_kwargs: t.Any) -> bool:
        window_size = self.sentinel_token_ids.shape[-1]
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]
            # Fewer new tokens than the sentinel length: can't unfold yet.
            if trimmed_sample.shape[-1] < window_size:
                continue
            # Slide a sentinel-sized window over the freshly generated tokens.
            for window in trimmed_sample.unfold(0, window_size, 1):
                if torch.all(torch.eq(self.sentinel_token_ids, window)):
                    return True
        return False


def run_raw_inference(model: transformers.AutoModelForCausalLM,
                      tokenizer: transformers.AutoTokenizer, prompt: str,
                      user_message: str, **kwargs: t.Any) -> str:
    '''
    Runs inference on the model, and attempts to return only the newly
    generated text.

    Generation stops early once the model emits "\\nYou:" in newly generated
    tokens.

    :param model: Model to perform inference with.
    :param tokenizer: Tokenizer to tokenize input with.
    :param prompt: Input to feed to the model.
    :param user_message: The user's raw message, exactly as appended to the end
        of `prompt`. Used for trimming the original input from the model output.
    :param kwargs: Extra keyword arguments forwarded to ``model.generate``.
    :return: Decoded model generation.
    :raises Exception: If ``user_message`` does not occur in ``prompt``.
    '''
    tokenized_items = tokenizer(prompt, return_tensors="pt").to("cpu")

    # Atrocious code to stop generation when the model outputs "\nYou: " in
    # freshly generated text. Feel free to send in a PR if you know of a
    # cleaner way to do this.
    stopping_criteria_list = transformers.StoppingCriteriaList([
        _SentinelTokenStoppingCriteria(
            sentinel_token_ids=tokenizer(
                "\nYou:",
                add_special_tokens=False,
                return_tensors="pt",
            ).input_ids.to("cpu"),
            # Skip the prompt tokens; only search newly generated ones.
            starting_idx=tokenized_items.input_ids.shape[-1])
    ])

    logits = model.generate(stopping_criteria=stopping_criteria_list,
                            **tokenized_items,
                            **kwargs)
    output = tokenizer.decode(logits[0], skip_special_tokens=True)

    logger.debug("Before trimming, model output was: `%s`", output)

    # Trim out the input prompt from the generated output.
    # NOTE(review): `idx` is a position in `prompt` but is used to slice
    # `output`; this assumes the decoded output starts with the prompt text
    # verbatim. The `- 1` offset is kept as-is — confirm it is intended.
    if (idx := prompt.rfind(user_message)) != -1:
        trimmed_output = output[idx + len(user_message) - 1:].strip()
        logger.debug("After trimming, it became: `%s`", trimmed_output)

        return trimmed_output
    else:
        raise Exception(
            "Couldn't find user message in the model's output. What?")


def _build_bad_words_list_for(_model_name: str) -> t.List[str]:
    '''Builds a list of bad words for the given model.'''

    # NOTE(11b): This was implemented as a function because each model size
    # seems to have it quirks at the moment, but this is a rushed implementation
    # so I'm not handling that, hence the dumb return here.
    return ["Persona:", "Scenario:", "<START>"]


#class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):

#    def __init__(self, sentinel_token_ids: torch.LongTensor,
#                 starting_idx: int):
#        transformers.StoppingCriteria.__init__(self)
#        self.sentinel_token_ids = sentinel_token_ids
#        self.starting_idx = starting_idx

#    def __call__(self, input_ids: torch.LongTensor,
#                 _scores: torch.FloatTensor) -> bool:
#        for sample in input_ids:
#            trimmed_sample = sample[self.starting_idx:]
#            # Can't unfold, output is still too tiny. Skip.
#            if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
#                continue

#            for window in trimmed_sample.unfold(
#                    0, self.sentinel_token_ids.shape[-1], 1):
#                if torch.all(torch.eq(self.sentinel_token_ids, window)):
#                    return True
#        return False
```