---
license: apache-2.0
---
This is a quantized version of HuatuoGPT2-7B, produced with the AutoAWQ framework. HuatuoGPT2-7B is a fine-tuned model whose base model is Baichuan-7B.

Model size: 16 GB original, 5 GB after quantization.

Inference accuracy has not been evaluated yet, so use with caution.

The calibration data used during quantization was the model's fine-tuning training set, Medical Fine-tuning Instruction (GPT-4).
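For reference, a quantization run along the following lines would produce such a checkpoint. This is a minimal sketch, not the script actually used: the source path `FreedomIntelligence/HuatuoGPT2-7B`, the `quant_config` values, and the one-line `calib_samples` placeholder are assumptions.
```
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# Hypothetical source checkpoint and output path; the exact script behind
# this card is not published, so treat this as a sketch only.
model_path = "FreedomIntelligence/HuatuoGPT2-7B"
quant_path = "huatuo_AutoAWQ_7B4bits"

# 4-bit weights; zero point, group size 128, and the GEMM kernel are the
# AutoAWQ defaults and are assumed here.
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

model = AutoAWQForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Stand-in for the Medical Fine-tuning Instruction (GPT-4) calibration set;
# AutoAWQ accepts a plain list of text samples as calib_data.
calib_samples = ["<问>:高血压患者日常饮食应注意什么?\n<答>:..."]
model.quantize(tokenizer, quant_config=quant_config, calib_data=calib_samples)

model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```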
Usage example (currently only AWQ is supported; Transformers integration is still being worked on):

Before you start, be sure to select a GPU:
```
import os

# Pin the process to GPU 0; this must be set before torch initializes CUDA
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
```
Make sure AutoAWQ is installed, e.g. by building from source as below (the prebuilt wheel `pip install autoawq` may also work):
```
!git clone https://github.com/casper-hansen/AutoAWQ
%cd AutoAWQ
!pip install -e .
```
```
from awq import AutoAWQForCausalLM
from awq.utils.utils import get_best_device
from transformers import AutoTokenizer, TextStreamer
quant_path = "jiangchengchengNLP/huatuo_AutoAWQ_7B4bits"
device = get_best_device()

# Load the quantized model and its tokenizer
model = AutoAWQForCausalLM.from_quantized(quant_path, device=device, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

# Stream decoded tokens to stdout as they are generated
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
prompt = "You're standing on the surface of the Earth. "\
"You walk one mile south, one mile west and one mile north. "\
"You end up exactly where you started. Where are you?"
chat = [
{"role": "user", "content": prompt},
]
# Stop generation at the end-of-sequence token ("<|eot_id|>" is a Llama-3
# token that does not exist in this model's vocabulary, so it is dropped)
terminators = [tokenizer.eos_token_id]
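# HuatuoGPT2's dialogue format renders user turns as "<问>:" and assistant
# turns as "<答>:", so set a matching chat template by hand.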
tokenizer.chat_template = """
{%- for message in messages -%}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{%- endif -%}
{%- if message['role'] == 'user' -%}
{{ '<问>:' + message['content'] + '\n' }}
{%- elif message['role'] == 'assistant' -%}
{{ '<答>:' + message['content'] + '\n' }}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- '<答>:' -}}
{% endif %}
"""
tokens = tokenizer.apply_chat_template(
    chat,
    add_generation_prompt=True,  # append the trailing "<答>:" so the model answers
    return_tensors="pt"
)
tokens = tokens.to(device)
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=1000,  # max_length is dropped; it conflicts with max_new_tokens
    eos_token_id=terminators,
)
```
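The streamer already prints the answer incrementally; if you also want the response as a string, the generated ids can be decoded afterwards. A small follow-up sketch (slicing off the prompt tokens is an assumption about what you want to keep):
```
# Keep only the newly generated tokens, i.e. everything after the prompt
response = tokenizer.decode(
    generation_output[0][tokens.shape[-1]:],
    skip_special_tokens=True,
)
print(response)
```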