# TinyLlama Inference

This README shows how to load the quantized `huzaifa1117/tinyllama_AWQ_4bit` model and run inference with it on a CUDA device.

## Installation

Install the required libraries. The `awq` module imported in the usage example is distributed on PyPI as `autoawq`:

```bash
pip install torch transformers peft autoawq
```

## Usage

### Model Loading and Inference

```python
import torch
from transformers import AutoTokenizer
from awq import AutoAWQForCausalLM

# Run on the CUDA device (this example assumes a GPU is present)
device = torch.device("cuda")

# Model ID of the quantized checkpoint
model_id = "huzaifa1117/tinyllama_AWQ_4bit"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the already-quantized weights onto the GPU.
# from_quantized is AutoAWQ's loader for quantized checkpoints;
# from_pretrained is meant for unquantized weights that are about to be quantized.
model = AutoAWQForCausalLM.from_quantized(model_id, fuse_layers=False)

# Tokenize input and run inference
input_text = "Your input text here"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

# max_length counts the prompt tokens plus the generated tokens
output = model.generate(input_ids, max_length=50)

# Decode and print the output
output_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(output_text)
```

### Notes

- Quantization stores the weights at lower precision, which reduces the model's memory footprint and can lower inference cost.
- The example requires a CUDA-capable GPU, because it hard-codes `torch.device("cuda")`.
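
Since the example assumes a GPU, it can help to check for one before loading the model. The sketch below is one way to do that. `cuda_status` is a hypothetical helper name, not part of any library, and it treats a missing `torch` install the same as a missing GPU.

```python
import importlib.util


def cuda_status():
    """Return (available, message) describing whether a CUDA device can be used."""
    # Treat a missing torch install like a missing GPU instead of raising.
    if importlib.util.find_spec("torch") is None:
        return False, "torch is not installed"

    import torch

    if not torch.cuda.is_available():
        return False, "torch is installed, but no CUDA device is visible"

    # Report the name of the first visible device.
    return True, f"CUDA device found: {torch.cuda.get_device_name(0)}"


available, message = cuda_status()
print(message)
```

If `available` is `False`, stopping before the model download avoids a less readable failure later in the loading step.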
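
## Quantization

The following `HqqConfig` describes 1-bit HQQ quantization of the linear layers, with weights grouped in blocks of 64 along axis 1, aimed at reducing memory use on resource-constrained hardware. It takes effect only when passed as `quantization_config` to `AutoModelForCausalLM.from_pretrained` on an unquantized checkpoint; a checkpoint that is already AWQ-quantized loads without it:

```python
from transformers import HqqConfig

quant_config = HqqConfig(nbits=1, group_size=64, quant_zero=False, quant_scale=False, axis=1)
```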
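
To make the effect of `nbits` and `group_size` concrete, the sketch below estimates the storage needed for group-wise quantized weights. It is a back-of-the-envelope calculation under stated assumptions, not a measurement of this checkpoint: `estimate_quantized_size_mb` is a hypothetical helper, the 1.1 billion parameter count is TinyLlama's nominal size, each group is assumed to store one 16-bit scale and one 16-bit zero point, and layers that stay unquantized (such as embeddings) are ignored.

```python
def estimate_quantized_size_mb(n_params, nbits, group_size, meta_bits_per_group=32):
    """Rough size in MB of n_params weights quantized group-wise.

    Assumption: every group of `group_size` weights carries
    `meta_bits_per_group` bits of metadata (scale plus zero point).
    """
    weight_bits = n_params * nbits
    n_groups = n_params / group_size
    meta_bits = n_groups * meta_bits_per_group
    return (weight_bits + meta_bits) / 8 / 1e6


# Assumption: nominal TinyLlama parameter count, all weights quantized.
N_PARAMS = 1_100_000_000

fp16_mb = N_PARAMS * 16 / 8 / 1e6
four_bit_mb = estimate_quantized_size_mb(N_PARAMS, nbits=4, group_size=64)
one_bit_mb = estimate_quantized_size_mb(N_PARAMS, nbits=1, group_size=64)

print(f"fp16 baseline:      {fp16_mb:.2f} MB")
print(f"4-bit, group of 64: {four_bit_mb:.2f} MB")
print(f"1-bit, group of 64: {one_bit_mb:.2f} MB")
```

Under these assumptions the per-group metadata adds 0.5 bits per weight, so the effective cost is about 4.5 bits per weight at `nbits=4` and 1.5 at `nbits=1`. A larger `group_size` shrinks that overhead at the price of coarser scaling.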

## License

This project is licensed under the terms of the MIT license.