|
import os
import re
import shutil
import subprocess
import sys

import streamlit as st
from huggingface_hub import HfApi
|
|
|
def _lookup_setting(secret_name, *env_names):
    """Return the first truthy setting, checking Streamlit secrets first.

    ``st.secrets`` is consulted for *secret_name* before the environment
    variables listed in *env_names* (in order). Lookup stops at the first
    truthy value; when none is truthy, the last value looked up is returned,
    exactly like a chain of ``or`` expressions.
    """
    candidate = st.secrets.get(secret_name)
    for env_name in env_names:
        if candidate:
            break
        candidate = os.environ.get(env_name)
    return candidate


# Token handed to HfApi for checking, creating and uploading repositories.
HF_TOKEN = _lookup_setting("HF_TOKEN", "HF_TOKEN")

# Namespace that receives the converted models; SPACE_AUTHOR_NAME (set by
# Hugging Face Spaces) is the last fallback.
HF_USERNAME = _lookup_setting("HF_USERNAME", "HF_USERNAME", "SPACE_AUTHOR_NAME")

# Pinned transformers.js checkout whose conversion script does the export.
TRANSFORMERS_REPOSITORY_URL = "https://github.com/xenova/transformers.js"
TRANSFORMERS_REPOSITORY_REVISION = "2.16.0"
TRANSFORMERS_REPOSITORY_PATH = "./transformers.js"

HF_BASE_URL = "https://huggingface.co"
|
|
|
# Ensure a transformers.js checkout at the pinned revision exists; its
# `scripts.convert` module is run later (with this directory as cwd).
# Commands are passed as argument lists (no shell), and check=True makes a
# failed clone/checkout stop the app with a clear error instead of letting
# the converter fail later for an unrelated-looking reason.
if not os.path.exists(TRANSFORMERS_REPOSITORY_PATH):
    subprocess.run(
        ["git", "clone", TRANSFORMERS_REPOSITORY_URL, TRANSFORMERS_REPOSITORY_PATH],
        check=True,
    )

# Streamlit re-executes the whole script on every interaction, so this runs
# on each rerun; checking out the already-active revision is quick.
subprocess.run(
    [
        "git",
        "-C",
        TRANSFORMERS_REPOSITORY_PATH,
        "checkout",
        "--quiet",
        TRANSFORMERS_REPOSITORY_REVISION,
    ],
    check=True,
)
|
|
|
# Page heading (rendered as Markdown by st.write) and input prompt.
_PAGE_HEADING = "## Convert a HuggingFace model to ONNX"
_MODEL_ID_PROMPT = (
    "Enter the HuggingFace model ID to convert. "
    "Example: `EleutherAI/pythia-14m`"
)

st.write(_PAGE_HEADING)

# Text typed by the user; the conversion section of the page keys off it.
input_model_id = st.text_input(_MODEL_ID_PROMPT)
|
|
|
# Local safety check for model IDs: an optional "namespace/" plus a repo
# name, each starting with a word character. ".." is rejected separately
# because the ID becomes part of a path that is later deleted with rmtree.
_MODEL_ID_PATTERN = re.compile(r"(?:\w[\w.-]*/)?\w[\w.-]*")

# Files written by the converter and the names they are uploaded under.
# NOTE(review): presumably these names are what transformers.js expects for
# merged decoder models — confirm for non-decoder architectures.
_ONNX_RENAMES = (
    ("model.onnx", "decoder_model_merged.onnx"),
    ("model_quantized.onnx", "decoder_model_merged_quantized.onnx"),
)


def _normalize_model_id(raw_model_id):
    """Return the bare model ID from the text the user entered.

    Strips surrounding whitespace, a leading ``https://huggingface.co/`` and
    stray slashes, so a pasted model page URL works both for naming the
    output repo and for the converter. ``None`` is treated as empty.
    """
    model_id = (raw_model_id or "").strip()
    url_prefix = f"{HF_BASE_URL}/"
    if model_id.startswith(url_prefix):
        model_id = model_id[len(url_prefix):]
    return model_id.strip("/")


def _is_valid_model_id(model_id):
    """Return True if *model_id* looks like a Hub ID and is safe in a path."""
    return _MODEL_ID_PATTERN.fullmatch(model_id) is not None and ".." not in model_id


def _output_model_name(model_id, username):
    """Return the output repo name (without namespace) for *model_id*.

    ``namespace/name`` becomes ``namespace-name``; a leading ``username-`` is
    dropped when converting one of the user's own models. Only a prefix is
    removed — ``str.replace`` would also cut the text out of the middle of
    unrelated names (e.g. username ``ia`` mangling ``EleutherAI-pythia-14m``).
    """
    name = model_id.replace("/", "-")
    own_prefix = f"{username}-"
    if name.startswith(own_prefix):
        name = name[len(own_prefix):]
    return name


def _run_converter(model_id):
    """Run transformers.js ``scripts.convert --quantize`` on *model_id*.

    Uses the interpreter running this app (``sys.executable``) rather than
    whichever ``python`` is first on PATH. stdout/stderr are captured as text.
    Returns the ``subprocess.CompletedProcess``.
    """
    return subprocess.run(
        [sys.executable, "-m", "scripts.convert", "--quantize", "--model_id", model_id],
        cwd=TRANSFORMERS_REPOSITORY_PATH,
        capture_output=True,
        text=True,
    )


def _rename_onnx_files(onnx_folder_path):
    """Rename converter outputs per ``_ONNX_RENAMES``, skipping absent files."""
    for source_name, target_name in _ONNX_RENAMES:
        source_path = os.path.join(onnx_folder_path, source_name)
        if os.path.exists(source_path):
            os.rename(source_path, os.path.join(onnx_folder_path, target_name))


def _upload_model(api, model_folder_path, output_model_id):
    """Create the public output repo if needed and upload *model_folder_path*.

    Returns the upload error message, or None on success. Errors raised by
    ``create_repo`` propagate; upload errors are returned for display.
    """
    repository = api.create_repo(output_model_id, exist_ok=True, private=False)
    try:
        api.upload_folder(folder_path=model_folder_path, repo_id=repository.repo_id)
    except Exception as error:  # reported to the user on the page, not raised
        return str(error)
    return None


def _render_conversion(model_id):
    """Show the status of *model_id* and, when requested, convert and upload it."""
    if not HF_USERNAME:
        st.error(
            "No HuggingFace username is configured (set `HF_USERNAME`), "
            "so there is no account to upload the converted model to."
        )
        return
    if not _is_valid_model_id(model_id):
        st.error(f"`{model_id}` is not a valid HuggingFace model ID.")
        return

    model_name = _output_model_name(model_id, HF_USERNAME)
    output_model_id = f"{HF_USERNAME}/{model_name}-ONNX"
    output_model_url = f"{HF_BASE_URL}/{output_model_id}"
    api = HfApi(token=HF_TOKEN)

    if api.repo_exists(output_model_id):
        st.write("This model has already been converted! 🎉")
        st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
        return

    st.write("This model will be converted and uploaded to the following URL:")
    st.code(output_model_url, language="plaintext")
    if not st.button(label="Proceed", type="primary"):
        return

    # Where scripts.convert writes its output, relative to the app's cwd.
    model_folder_path = os.path.join(TRANSFORMERS_REPOSITORY_PATH, "models", model_id)
    try:
        with st.spinner("Converting model..."):
            result = _run_converter(model_id)
            if result.returncode == 0:
                _rename_onnx_files(os.path.join(model_folder_path, "onnx"))

        if result.returncode != 0:
            st.error(f"Conversion failed (exit code {result.returncode}).")
            st.code(result.stderr)
            return

        st.success("Conversion successful!")
        # The converter reports progress on stderr; show it for reference.
        st.code(result.stderr)

        with st.spinner("Uploading model..."):
            upload_error_message = _upload_model(api, model_folder_path, output_model_id)
    finally:
        # Remove the local export whether conversion/upload succeeded or not.
        shutil.rmtree(model_folder_path, ignore_errors=True)

    if upload_error_message:
        st.error(f"Upload failed: {upload_error_message}")
    else:
        st.success("Upload successful!")
        st.write("You can now go and view the model on HuggingFace!")
        st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")


model_id = _normalize_model_id(input_model_id)
if model_id:
    _render_conversion(model_id)
|
|