Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModel | |
import tempfile | |
from huggingface_hub import HfApi | |
from huggingface_hub import list_models | |
from packaging import version | |
import os | |
import spaces | |
def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str:
    """Return a greeting for the logged-in user.

    ``app.load`` registers this with ``inputs=None``. Both arguments are
    therefore expected to be supplied by Gradio from the login session.
    ``profile`` is ``None`` when the visitor is not signed in.
    ``oauth_token`` is accepted but not used here.
    """
    # Anonymous visitors get a greeting without a name.
    greeting = "Hello !" if profile is None else f"Hello {profile.name} !"
    return greeting
def check_model_exists(oauth_token: gr.OAuthToken | None, username, quantization_type, group_size, model_name, quantized_model_name):
    """Check whether the target quantized repo already exists on the Hugging Face Hub.

    The target repo id is ``{username}/{quantized_model_name}`` when a custom name
    is given. Otherwise it is derived from the base model name and the quantization
    settings. This naming must stay in sync with the one used when saving the model.

    Args:
        oauth_token: OAuth token of the signed-in user; its ``.token`` is used for the Hub query.
        username: Namespace the quantized model will be pushed to.
        quantization_type: torchao quantization type string.
        group_size: Group size (part of the default name only for "int4_weight_only").
        model_name: Hub id of the base model.
        quantized_model_name: Optional custom repo name overriding the default.

    Returns:
        A user-facing message if the repo already exists or the check failed,
        otherwise ``None``.
    """
    try:
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        elif quantization_type == "int4_weight_only":
            repo_name = f"{username}/{model_name.split('/')[-1]}-torchao-{quantization_type.lower()}-gs_{group_size}"
        else:
            repo_name = f"{username}/{model_name.split('/')[-1]}-torchao-{quantization_type.lower()}"
        # Query the single target repo instead of paginating through every model
        # the user owns. The Hub also resolves repo ids case-insensitively here,
        # matching how create_repo would collide.
        api = HfApi(token=oauth_token.token)
        if api.repo_exists(repo_name, repo_type="model"):
            return f"Model '{repo_name}' already exists in your repository."
        return None  # Model does not exist
    except Exception as e:
        # Best-effort check: report the failure to the caller instead of raising.
        return f"Error checking model existence: {str(e)}"
def create_model_card(model_name, quantization_type, group_size, repo_name=None):
    """Render the README.md model card for a torchao-quantized model.

    Args:
        model_name: Hub id of the original model that was quantized. It is written
            to the ``base_model`` front-matter field, the title and the description.
        quantization_type: torchao quantization type string (e.g. "int8_weight_only").
        group_size: Group size. It is only reported for "int4_weight_only";
            otherwise the card shows ``None``.
        repo_name: Hub id of the quantized repo the card describes, used in the
            loading snippet. Defaults to ``model_name`` so existing
            three-argument callers keep working.

    Returns:
        The model card as a Markdown string with YAML front matter.
    """
    if repo_name is None:
        repo_name = model_name
    shown_group_size = group_size if quantization_type == "int4_weight_only" else None
    model_card = f"""---
base_model:
- {model_name}
---

# {model_name} (Quantized)

## Description
This model is a quantized version of the original model `{model_name}`. It has been quantized using {quantization_type} quantization with torchao.

## Quantization Details
- **Quantization Type**: {quantization_type}
- **Group Size**: {shown_group_size}

## Usage
You can use this model in your applications by loading it directly from the Hugging Face Hub:

```python
from transformers import AutoModel
model = AutoModel.from_pretrained("{repo_name}")
```
"""
    return model_card


def _load_quantized_model(model_name, quantization_config, auth_token):
    """Load ``model_name`` in bfloat16 with ``quantization_config`` applied (shared by both loaders)."""
    return AutoModel.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
        # ``token`` replaces the deprecated ``use_auth_token`` keyword.
        token=auth_token.token,
    )


def load_model_gpu(model_name, quantization_config, auth_token):
    """Load the model for the "cuda" device choice.

    NOTE(review): this is currently identical to load_model_cpu. No device_map is
    passed, so placement is whatever from_pretrained defaults to. Confirm whether
    an explicit CUDA placement was intended.
    """
    return _load_quantized_model(model_name, quantization_config, auth_token)


def load_model_cpu(model_name, quantization_config, auth_token):
    """Load the model for the "cpu" device choice (same loading path as load_model_gpu)."""
    return _load_quantized_model(model_name, quantization_config, auth_token)


def quantize_model(model_name, quantization_type, group_size=128, auth_token=None, username=None, device="cuda"):
    """Load ``model_name`` with a torchao quantization config and return the model.

    ``group_size`` is forwarded to ``TorchAoConfig`` only for "int4_weight_only".
    ``device == "cuda"`` selects load_model_gpu; any other value selects
    load_model_cpu. ``username`` is accepted for interface compatibility and is
    not used.
    """
    print(f"Quantizing model: {quantization_type}")
    if quantization_type == "int4_weight_only":
        quantization_config = TorchAoConfig(quantization_type, group_size=group_size)
    else:
        quantization_config = TorchAoConfig(quantization_type)
    loader = load_model_gpu if device == "cuda" else load_model_cpu
    return loader(model_name, quantization_config=quantization_config, auth_token=auth_token)


def save_model(model, model_name, quantization_type, group_size=128, username=None, auth_token=None, quantized_model_name=None):
    """Save ``model`` with a model card and upload it to ``username``'s namespace on the Hub.

    Steps:
        1. Serialize the model to a temporary directory.
        2. Write README.md there.
        3. Create the repo if needed.
        4. Upload the whole folder.

    Returns:
        The URL of the uploaded repo.
    """
    print("Saving quantized model")
    # Naming must stay in sync with check_model_exists.
    if quantized_model_name:
        repo_name = f"{username}/{quantized_model_name}"
    elif quantization_type == "int4_weight_only":
        repo_name = f"{username}/{model_name.split('/')[-1]}-torchao-{quantization_type.lower()}-gs_{group_size}"
    else:
        repo_name = f"{username}/{model_name.split('/')[-1]}-torchao-{quantization_type.lower()}"
    with tempfile.TemporaryDirectory() as tmpdirname:
        # NOTE(review): safe_serialization=False presumably because torchao's
        # quantized tensors are not safetensors-serializable; confirm before changing.
        model.save_pretrained(tmpdirname, safe_serialization=False)
        # The card's base_model must be the ORIGINAL model; the quantized repo id
        # is only used in the loading snippet.
        model_card = create_model_card(model_name, quantization_type, group_size, repo_name=repo_name)
        with open(os.path.join(tmpdirname, "README.md"), "w", encoding="utf-8") as f:
            f.write(model_card)
        # Push to Hub while the temporary directory still exists.
        api = HfApi(token=auth_token.token)
        api.create_repo(repo_name, exist_ok=True)
        api.upload_folder(
            folder_path=tmpdirname,
            repo_id=repo_name,
            repo_type="model",
        )
    return f"https://huggingface.co/{repo_name}"
def quantize_and_save(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, model_name, quantization_type, group_size, quantized_model_name, device):
    """Gradio click handler: quantize ``model_name`` and push it to the user's namespace.

    ``profile`` and ``oauth_token`` are not listed in the button's ``inputs``.
    They are presumably injected by Gradio from the login session based on the
    type hints. The remaining arguments come from the UI widgets.

    Returns:
        The URL of the uploaded repo on success. Otherwise, a human-readable
        message that is displayed in the output textbox.
    """
    import traceback

    if oauth_token is None:
        return "Error : Please Sign In to your HuggingFace account to use the quantizer"
    if not profile:
        return "Error: Please Sign In to your HuggingFace account to use the quantizer"
    # Validate the combination locally before making any Hub request.
    if quantization_type == "int4_weight_only" and device == "cpu":
        return "int4_weight_only not supported on cpu"
    exists_message = check_model_exists(oauth_token, profile.username, quantization_type, group_size, model_name, quantized_model_name)
    if exists_message:
        return exists_message
    try:
        quantized_model = quantize_model(model_name, quantization_type, group_size, oauth_token, profile.username, device)
        return save_model(quantized_model, model_name, quantization_type, group_size, profile.username, oauth_token, quantized_model_name)
    except Exception as e:
        # UI boundary: keep the full traceback in the container logs, but show the
        # user a readable message instead of leaving the output box empty.
        traceback.print_exc()
        return f"Error: {e}"
# Only the .center-button rule was ever effective: the earlier #username-box rule
# was overwritten, and the username widget itself is no longer part of the UI.
APP_CSS = """
.center-button {
    display: flex;
    justify-content: center;
    align-items: center;
    margin: 0 auto; /* Center horizontally */
}
"""

# CSS is passed through the documented Blocks constructor parameter instead of
# being assigned to app.css after construction.
with gr.Blocks(theme=gr.themes.Soft(), css=APP_CSS) as app:
    gr.Markdown(
        """
        # 🚀 LLM Model Quantization App
        Quantize your favorite Hugging Face models and save them to your profile!
        """
    )
    gr.LoginButton(elem_id="login-button", elem_classes="center-button")
    m1 = gr.Markdown()
    # Greets the signed-in user when the page loads.
    app.load(hello, inputs=None, outputs=m1)

    with gr.Row():
        with gr.Column():
            model_name = gr.Textbox(
                label="Model Name",
                placeholder="e.g., meta-llama/Meta-Llama-3-8B",
                value="meta-llama/Meta-Llama-3-8B",
            )
            quantization_type = gr.Dropdown(
                label="Quantization Type",
                choices=["int4_weight_only", "int8_weight_only", "int8_dynamic_activation_int8_weight"],
                value="int8_weight_only",
            )
            # precision=0 makes Gradio round to an int, so a float never reaches
            # TorchAoConfig or the generated repo name.
            group_size = gr.Number(
                label="Group Size (only for int4_weight_only)",
                value=128,
                precision=0,
                interactive=True,
            )
            device = gr.Dropdown(
                label="Device (int4 only works with cuda)",
                choices=["cuda", "cpu"],
                value="cuda",
            )
            quantized_model_name = gr.Textbox(
                label="Model Name (optional : to override default)",
                value="",
                interactive=True,
            )
        with gr.Column():
            quantize_button = gr.Button("Quantize and Save Model", variant="primary")
            output_link = gr.Textbox(label="Quantized Model Link")
            gr.Markdown(
                """
                ## Instructions
                1. Login to your HuggingFace account
                2. Enter the name of the Hugging Face LLM model you want to quantize (Make sure you have access to it)
                3. Choose the quantization type.
                4. Optionally, specify the group size.
                5. Optionally, choose a custom name for the quantized model
                6. Click "Quantize and Save Model" to start the process.
                7. Once complete, you'll receive a link to the quantized model on Hugging Face.
                Note: This process may take some time depending on the model size and your hardware. You can check the container logs to see where you are in the process!
                """
            )

    # The OAuth profile/token arguments of quantize_and_save are not listed here;
    # presumably Gradio injects them from the login session via the type hints.
    quantize_button.click(
        fn=quantize_and_save,
        inputs=[model_name, quantization_type, group_size, quantized_model_name, device],
        outputs=[output_link],
    )

# Launch the app
app.launch()