---
tags:
- text-generation-inference
- mistral
- 4-bit precision
- AWQ
base_model:
- mistralai/Mistral-7B-v0.1
---
# Mistral 7B v0.1 with key-value cache (KVC) enabled in ONNX AWQ (4-bit) format
- Model creator: [Mistral AI](https://huggingface.co/mistralai)
- Original model: [Mistral AI's Mistral 7B v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
<!-- description start -->
## Description
This repo contains the ONNX files for Mistral 7B v0.1, converted to ONNX by Esperanto Technologies.
The model is quantized to 4 bits with AWQ and has the key-value cache (KVC) enabled.
### About AWQ
AWQ (Activation-aware Weight Quantization) is an efficient and accurate low-bit weight quantization method that currently supports 4-bit quantization. Compared with GPTQ, it offers faster Transformers-based inference with quality equivalent to or better than the most commonly used GPTQ settings.
More information: [AutoAWQ](https://github.com/casper-hansen/AutoAWQ)
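The core idea behind AWQ can be illustrated in a few lines of NumPy. This is a simplified sketch on toy random data, not AutoAWQ's implementation: it assumes asymmetric round-to-nearest 4-bit quantization with a group size of 128 (a common AWQ setting), scales each input channel by its average activation magnitude raised to a power `alpha`, and grid-searches `alpha` to minimize the output error of a single linear layer. The function names `quantize_groupwise` and `awq_style_search` are hypothetical.

```python
import numpy as np


def quantize_groupwise(w, n_bits=4, group_size=128):
    """Asymmetric round-to-nearest quantization of each row of `w` in groups of
    `group_size` consecutive input channels. Returns the de-quantized weights."""
    out_features, in_features = w.shape
    groups = w.reshape(out_features, in_features // group_size, group_size)
    w_min = groups.min(axis=-1, keepdims=True)
    w_max = groups.max(axis=-1, keepdims=True)
    q_max = 2 ** n_bits - 1  # 15 for 4-bit codes
    scale = np.maximum(w_max - w_min, 1e-8) / q_max
    zero = np.round(-w_min / scale)
    codes = np.clip(np.round(groups / scale) + zero, 0, q_max)  # integers in [0, q_max]
    return ((codes - zero) * scale).reshape(out_features, in_features)


def awq_style_search(w, x, n_bits=4, group_size=128, n_grid=20):
    """Grid-search alpha for per-input-channel scales s = mean|x| ** alpha and keep
    the scales that minimize the layer's output error after quantization."""
    reference = x @ w.T                  # full-precision output of the linear layer
    act_mag = np.abs(x).mean(axis=0)     # average activation magnitude per input channel
    best_err, best_alpha, best_w = np.inf, None, None
    for alpha in np.linspace(0.0, 1.0, n_grid + 1):  # alpha = 0 is plain round-to-nearest
        s = np.maximum(act_mag, 1e-4) ** alpha
        s = s / np.sqrt(s.max() * s.min())  # keep the scales centred around 1
        # Scale salient input channels up before quantizing, then fold 1/s back in;
        # at inference time the 1/s factor would be applied to the activations instead.
        w_deq = quantize_groupwise(w * s, n_bits, group_size) / s
        err = np.mean((x @ w_deq.T - reference) ** 2)
        if err < best_err:
            best_err, best_alpha, best_w = err, alpha, w_deq
    return best_err, best_alpha, best_w


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    w = rng.normal(size=(64, 256))  # toy weight matrix: 64 outputs, 256 inputs
    x = rng.normal(size=(32, 256))  # toy calibration activations
    x[:, :4] *= 20.0                # a few input channels with large activations
    rtn_err = np.mean((x @ quantize_groupwise(w).T - x @ w.T) ** 2)
    best_err, best_alpha, _ = awq_style_search(w, x)
    print(f"plain 4-bit RTN error: {rtn_err:.4f}")
    print(f"activation-aware error: {best_err:.4f} (alpha={best_alpha:.2f})")
```

Because `alpha = 0` reproduces plain round-to-nearest, the search can only match or improve on it for the calibration data it sees.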
<!-- description end -->
## How to download the ONNX model and weight files
The easiest way to obtain the model is to clone this whole repo.
Alternatively, you can download the files using the `huggingface-hub` Python library:
```shell
pip3 install "huggingface-hub>=0.17.1"
```
Then you can download the model files into a local folder with a command like this:
```shell
huggingface-cli download Esperanto/mistral-7b-kvc-AWQ-int4-onnx --local-dir mistral-7b-kvc-AWQ-int4-onnx --local-dir-use-symlinks False
```
For more documentation on downloading with `huggingface-cli`, please see: [HF -> Hub Python Library -> Download files -> Download from the CLI](https://huggingface.co/docs/huggingface_hub/guides/download#download-from-the-cli).
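If you prefer to stay in Python, the same download can be done with `snapshot_download` from the `huggingface_hub` library. This is a minimal sketch: the wrapper name `download_model` is hypothetical, and it only forwards the repository ID and target folder used in the CLI command.

```python
def download_model(local_dir="mistral-7b-kvc-AWQ-int4-onnx"):
    """Download every file of the model repository into `local_dir` and return its path."""
    # Imported lazily so this helper can be defined even where huggingface_hub is not installed
    from huggingface_hub import snapshot_download

    return snapshot_download(
        repo_id="Esperanto/mistral-7b-kvc-AWQ-int4-onnx",
        local_dir=local_dir,
    )
```

Calling `download_model()` fetches the repository into `mistral-7b-kvc-AWQ-int4-onnx/`, the folder that the inference example below reads `model.onnx` from.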
## How to run from Python code using ONNXRuntime
This model can be run on a CPU using [ONNXRuntime](https://onnxruntime.ai/).
#### First, install the packages
```bash
pip3 install onnx==1.16.1
pip3 install onnxruntime==1.17.1
pip3 install transformers
```
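Because the example pins exact versions of `onnx` and `onnxruntime`, it can help to confirm what is actually installed before running it. The helper below is a hypothetical convenience, not part of either library; it relies only on the standard-library `importlib.metadata` module.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned by the install commands (package name -> expected version)
PINNED_VERSIONS = {"onnx": "1.16.1", "onnxruntime": "1.17.1"}


def find_version_mismatches(pinned):
    """Return {package: (expected, installed)} for every package whose installed version
    differs from the pinned one; `installed` is None when the package is missing."""
    mismatches = {}
    for package, expected in pinned.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[package] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for package, (expected, installed) in find_version_mismatches(PINNED_VERSIONS).items():
        print(f"{package}: expected {expected}, found {installed}")
```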
#### Example code: generate text with this model
Define the generation loop, which uses greedy decoding:
```python
import numpy as np
import onnxruntime
from transformers import AutoTokenizer


def generate_text(model_path, prompt, tokenizer, max_gen_tokens, total_sequence, window, context):
    # Load the model once and read its input/output names from the session
    rt_session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    inputs_names = [node.name for node in rt_session.get_inputs()]
    output_names = [node.name for node in rt_session.get_outputs()]

    # Create the inputs for the first iteration (as NumPy arrays, so PyTorch is not required)
    input_tensor = tokenizer(prompt, return_tensors="np")
    actual_input = input_tensor["input_ids"].astype(np.int64)
    if actual_input.shape[1] < window:
        # Left-pad short prompts with BOS tokens so that the first window is full
        padding = tokenizer.bos_token_id * np.ones([1, window - actual_input.shape[1]], dtype=np.int64)
        actual_input = np.concatenate((padding, actual_input), axis=1)
    prompt_size = actual_input.shape[1]  # prompt length, including any padding
    if prompt_size + max_gen_tokens > total_sequence:
        raise ValueError("A longer total_sequence is needed for this prompt and max_gen_tokens.")
    first_attention = np.concatenate((np.zeros([1, total_sequence - window], dtype=np.int64),
                                      np.ones((1, window), dtype=np.int64)), axis=1)
    max_gen_tokens += prompt_size  # new tokens are generated on top of parsing the prompt
    n_heads = 8  # key/value heads of the KV cache (Mistral 7B uses grouped-query attention)
    head_dim = 128

    inputs_dict = {}
    inputs_dict["input_ids"] = actual_input[:, :window]
    inputs_dict["attention_mask"] = first_attention
    index_pos = int(first_attention[0].sum())
    inputs_dict["position_ids"] = np.concatenate((np.zeros([1, total_sequence - index_pos], dtype=np.int64),
                                                  np.arange(index_pos, dtype=np.int64).reshape(1, index_pos)), axis=1)
    # Causal mask: -65504 (the float16 minimum) above the diagonal, 0 elsewhere
    causal_mask = np.triu(-65504 * np.ones((total_sequence, total_sequence)), k=1)
    inputs_dict["tree_attention"] = causal_mask.astype(np.float16).reshape(1, 1, total_sequence, total_sequence)
    # All remaining inputs are past key/value tensors: start from an empty (all-zero) cache
    for name in inputs_names:
        if name in ("input_ids", "attention_mask", "position_ids", "tree_attention"):
            continue
        inputs_dict[name] = np.zeros([1, n_heads, context - window, head_dim], dtype=np.float16)

    next_index = window
    old_j = 0
    total_input = actual_input

    # Run the inferences
    while next_index < max_gen_tokens:
        # Stop once the most recently generated token is EOS
        if total_input.shape[1] > prompt_size and total_input[0, -1] == tokenizer.eos_token_id:
            break
        output = rt_session.run(output_names, inputs_dict)
        outs_dictionary = dict(zip(output_names, output))
        # Prepare the inputs for the next inference
        for name in inputs_names:
            if name == "input_ids":
                old_j = next_index
                if next_index < prompt_size:
                    # Still parsing the prompt: advance by a full window or up to the end of the prompt
                    next_index = min(next_index + window, prompt_size)
                else:
                    # Greedy decoding: the argmax at the last position is the next token
                    next_index += 1
                    new_token = outs_dictionary["logits"].argmax(-1).reshape(1, window)
                    total_input = np.concatenate((total_input, new_token[:, -1:]), axis=1)
                j = next_index - window
                inputs_dict["input_ids"] = total_input[:, j:next_index]
            elif name == "attention_mask":
                inputs_dict["attention_mask"] = np.concatenate(
                    (np.zeros((1, total_sequence - next_index), dtype=np.int64),
                     np.ones((1, next_index), dtype=np.int64)), axis=1)
            elif name == "position_ids":
                inputs_dict["position_ids"] = np.concatenate(
                    (np.zeros([1, total_sequence - next_index], dtype=np.int64),
                     np.arange(next_index, dtype=np.int64).reshape(1, next_index)), axis=1)
            elif name == "tree_attention":
                continue
            else:
                # Slide the KV cache forward by the number of positions just consumed
                present_name = name.replace("past_key_values", "present")
                shift = next_index - old_j
                inputs_dict[name] = outs_dictionary[present_name][:, :, shift:context - window + shift, :]
    return tokenizer.decode(total_input[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
```
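The per-step bookkeeping in the loop (left-padded attention mask, position IDs, and the sliding KV cache) can be checked in isolation with small dummy arrays. This sketch pulls that logic into two hypothetical helpers, `step_masks` and `shift_cache`. It assumes, as the loop's slicing implies, that each `present.*` output holds `context` positions: the `context - window` cached ones followed by the `window` positions just computed.

```python
import numpy as np


def step_masks(next_index, total_sequence):
    """attention_mask and position_ids when the first `next_index` tokens are valid:
    valid positions are right-aligned, with zero padding on the left."""
    pad = total_sequence - next_index
    attention_mask = np.concatenate(
        (np.zeros((1, pad), dtype=np.int64), np.ones((1, next_index), dtype=np.int64)), axis=1)
    position_ids = np.concatenate(
        (np.zeros((1, pad), dtype=np.int64),
         np.arange(next_index, dtype=np.int64).reshape(1, next_index)), axis=1)
    return attention_mask, position_ids


def shift_cache(present, shift, context, window):
    """Turn a present.* output of length `context` into the next past_key_values.* input
    of length `context - window`, dropping the `shift` oldest positions."""
    return present[:, :, shift:context - window + shift, :]


if __name__ == "__main__":
    mask, pos = step_masks(next_index=5, total_sequence=8)
    print(mask)  # [[0 0 0 1 1 1 1 1]]
    print(pos)   # [[0 0 0 0 1 2 3 4]]
    context, window = 12, 4
    # Dummy present tensor whose values record the cache position index
    present = np.broadcast_to(np.arange(context, dtype=np.float16).reshape(1, 1, context, 1),
                              (1, 8, context, 2))
    past = shift_cache(present, shift=1, context=context, window=window)
    print(past.shape)        # (1, 8, 8, 2)
    print(past[0, 0, :, 0])  # cache positions 1 through 8
```

During prompt parsing the shift equals the number of new prompt tokens consumed; during generation it is always 1.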
Then run inference:
```python
tokenizer = AutoTokenizer.from_pretrained("Esperanto/mistral-7b-kvc-AWQ-int4-onnx")
model_path = "mistral-7b-kvc-AWQ-int4-onnx/model.onnx"
max_gen_tokens = 20  # number of new tokens to generate
total_sequence = 128  # total sequence length (prompt + generated tokens)
context = 1024  # length of the KV cache context
window = 16  # number of tokens processed per inference step
messages = [
{"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
{"role": "user", "content": "Who are you?"},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
generated = generate_text(model_path, prompt, tokenizer, max_gen_tokens, total_sequence, window, context)
print(generated)
```
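To choose `total_sequence` and to estimate how many model calls a request costs, the loop's `next_index` schedule can be replayed without running the model. This is a hypothetical helper that mirrors the logic of `generate_text`, under the assumptions that the prompt is at least `window` tokens long and that generation does not stop early at EOS.

```python
def plan_steps(prompt_size, max_gen_tokens, window, total_sequence):
    """Replay the next_index schedule of generate_text, ignoring early stopping on EOS.
    Assumes prompt_size >= window. Returns (inference calls, generated tokens)."""
    if prompt_size < window:
        raise ValueError("this sketch assumes prompt_size >= window")
    if prompt_size + max_gen_tokens > total_sequence:
        raise ValueError("total_sequence must cover prompt_size + max_gen_tokens")
    next_index, end = window, prompt_size + max_gen_tokens
    runs = generated = 0
    while next_index < end:
        runs += 1  # one rt_session.run call
        if next_index < prompt_size:
            # Parse the next chunk of the prompt: a full window, or up to the prompt end
            next_index = min(next_index + window, prompt_size)
        else:
            # Greedy step: one new token per call
            next_index += 1
            generated += 1
    return runs, generated


if __name__ == "__main__":
    print(plan_steps(prompt_size=40, max_gen_tokens=20, window=16, total_sequence=128))  # (22, 20)
```

With a 40-token prompt and `window = 16`, the prompt takes three calls (tokens 0 to 15, 16 to 31, then an overlapping window ending at token 40), followed by one call per generated token.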