File size: 6,470 Bytes
3cce3b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
---
tags:
- text-generation-inference
- mistral
base_model:
- mistralai/Mistral-7B-v0.1
---
# Mistral 7B v0.1 with Key-Value-Cache enabled in ONNX fp16 format
- Model creator: [MistralAI](https://huggingface.co/mistralai)
- Original model: [MistralAI Mistral 7B v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
<!-- description start -->
## Description
This repo contains the ONNX files for the ONNX conversion of Mistral 7B v0.1 done by Esperanto Technologies.
The model is in the fp16 format and has the KVC enabled.
<!-- description end -->
## How to download ONNX model and weight files
The easiest way to obtain the model is to clone this whole repo.
Alternatively, you can download the files using the `huggingface-hub` Python library.
```shell
# Quote the requirement: an unquoted ">" is interpreted by the shell as an
# output redirection, which drops the version constraint.
pip3 install "huggingface-hub>=0.17.1"
```
Then you can download the model files into a local `mistral-7b-kvc-fp16-onnx` directory, at high speed, with a command like this:
```shell
huggingface-cli download Esperanto/mistral-7b-kvc-fp16-onnx --local-dir mistral-7b-kvc-fp16-onnx --local-dir-use-symlinks False
```
For more documentation on downloading with `huggingface-cli`, please see: [HF -> Hub Python Library -> Download files -> Download from the CLI](https://huggingface.co/docs/huggingface_hub/guides/download#download-from-the-cli).
## How to run from Python code using ONNXRuntime
This model can easily be run on a CPU using [ONNXRuntime](https://onnxruntime.ai/).
#### First install the packages
```bash
pip3 install onnx==1.16.1
pip3 install onnxruntime==1.17.1
# transformers provides the AutoTokenizer used by the example code below
pip3 install transformers
```
#### Example code: generate text with this model
We define the loop with greedy decoding:
```python
import numpy as np
import onnxruntime
import onnx
from transformers import AutoTokenizer
def generate_text(model_path, prompt, tokenizer, max_gen_tokens, total_sequence, window, context):
    """Generate text from ``prompt`` with greedy decoding on an ONNX KV-cache model.

    The prompt is fed to the model ``window`` tokens at a time; once the whole
    prompt has been parsed, one new token is produced per inference by taking
    the argmax of the logits at the last window position.

    Args:
        model_path: Path to the ``.onnx`` model file.
        prompt: Text to continue.
        tokenizer: Tokenizer exposing ``__call__``, ``decode``,
            ``bos_token_id`` and ``eos_token_id``.
        max_gen_tokens: Maximum number of tokens to generate after the prompt.
        total_sequence: Length of the ``attention_mask``/``position_ids``
            inputs and side of the square ``tree_attention`` mask.
        window: Number of tokens fed to the model per inference.
        context: Length of the ``present`` outputs; the ``past_key_values``
            inputs hold ``context - window`` positions.

    Returns:
        The decoded prompt plus generated text (special tokens skipped), or
        ``None`` after printing an error when the prompt plus
        ``max_gen_tokens`` does not fit in ``total_sequence``.
    """
    fixed_inputs = ('input_ids', 'attention_mask', 'position_ids', 'tree_attention')
    n_heads = 8  # gqa-heads of the kvc
    head_dim = 128  # last KV-cache dimension; presumably the per-head size -- confirm

    # Ask for NumPy tensors so the token buffer has the same type whether or
    # not it gets padded below (and no torch install is required).
    input_tensor = tokenizer(prompt, return_tensors="np")
    actual_input = np.asarray(input_tensor['input_ids']).astype('int64')
    prompt_size = actual_input.shape[1]

    if prompt_size < window:
        # Left-pad with BOS so the first inference sees a full window. The
        # padded tokens are attended to, so from here on they count as prompt.
        padding = np.full((1, window - prompt_size), tokenizer.bos_token_id, dtype='int64')
        actual_input = np.concatenate((padding, actual_input), axis=1)
        prompt_size = window

    if prompt_size + max_gen_tokens > total_sequence:
        print("ERROR: Longer total sequence is needed!")
        return

    # Index at which generation stops: the prompt plus the requested tokens.
    stop_index = prompt_size + max_gen_tokens

    # The session already knows the graph's input/output names, so the model
    # does not have to be loaded a second time just to read them.
    rt_session = onnxruntime.InferenceSession(model_path)
    inputs_names = [node.name for node in rt_session.get_inputs()]
    output_names = [node.name for node in rt_session.get_outputs()]

    def right_aligned_mask(valid):
        # zeros followed by `valid` ones, total_sequence long
        return np.concatenate((np.zeros((1, total_sequence - valid), dtype='int64'),
                               np.ones((1, valid), dtype='int64')), axis=1)

    def right_aligned_positions(valid):
        # zeros followed by 0..valid-1, total_sequence long
        return np.concatenate((np.zeros((1, total_sequence - valid), dtype='int64'),
                               np.arange(valid, dtype='int64').reshape(1, valid)), axis=1)

    # Inputs for the first iteration.
    inputs_dict = {
        'input_ids': actual_input[:, :window].reshape(1, window),
        'attention_mask': right_aligned_mask(window),
        'position_ids': right_aligned_positions(window),
        # Strictly-upper-triangular mask filled with -65504 (the most negative
        # float16 value), zeros elsewhere.
        'tree_attention': np.triu(np.full((total_sequence, total_sequence), -65504.0), k=1)
                            .astype('float16')
                            .reshape(1, 1, total_sequence, total_sequence),
    }
    for name in inputs_names:
        if name not in fixed_inputs:
            # Every remaining input is a KV-cache tensor, initially empty.
            inputs_dict[name] = np.zeros([1, n_heads, context - window, head_dim], dtype="float16")

    next_index = window  # number of tokens of total_input consumed so far
    total_input = actual_input
    last_token = None  # most recently generated token id, if any

    # We run the inferences.
    while next_index < stop_index:
        if last_token is not None and last_token == tokenizer.eos_token_id:
            break

        output = rt_session.run(output_names, inputs_dict)
        outs_dictionary = dict(zip(output_names, output))

        # Advance the window first, so every input below is built from the
        # same indices regardless of the order of the graph inputs.
        previous_index = next_index
        if next_index < prompt_size:
            # Still parsing the prompt: move up to a full window forward.
            next_index = min(next_index + window, prompt_size)
        else:
            # Prompt consumed: append the greedy token for the last position.
            next_index += 1
            new_token = outs_dictionary['logits'].argmax(-1).reshape(1, window)
            last_token = int(new_token[0, -1])
            total_input = np.concatenate((total_input, new_token[:, -1:]), axis=1)
        shift = next_index - previous_index

        # We prepare the inputs for the next inference.
        inputs_dict['input_ids'] = total_input[:, next_index - window:next_index].reshape(1, window)
        inputs_dict['attention_mask'] = right_aligned_mask(next_index)
        inputs_dict['position_ids'] = right_aligned_positions(next_index)
        for name in inputs_names:
            if name in fixed_inputs:
                continue
            # Slide the cache forward by the number of tokens just consumed.
            present_name = name.replace("past_key_values", "present")
            inputs_dict[name] = outs_dictionary[present_name][:, :, shift:context - window + shift, :]

    answer = tokenizer.decode(total_input[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return answer
```
We now run the inferences:
```python
# Example driver: load the tokenizer, build a prompt and print the text
# returned by generate_text.
tokenizer = AutoTokenizer.from_pretrained("Esperanto/mistral-7b-kvc-fp16-onnx")
# Path of model.onnx inside the local directory used by the download command.
model_path = "mistral-7b-kvc-fp16-onnx/model.onnx"
max_gen_tokens = 20 # number of tokens we want to generate
total_sequence = 128 # total sequence length
context = 1024 # the context to extend the kvc
window = 16 # number of tokens we want to parse at a time
messages = [
{"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
{"role": "user", "content": "Who are you?"},
]
# NOTE(review): apply_chat_template relies on the tokenizer shipping a chat
# template; Mistral-7B-v0.1 is a base model -- confirm this repo's tokenizer
# defines one that accepts a "system" role.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# generate_text returns None (after printing an error) when the prompt plus
# max_gen_tokens does not fit in total_sequence.
generated = generate_text(model_path, prompt, tokenizer, max_gen_tokens, total_sequence, window, context)
print(generated)
``` |