# convert-to-onnx / app.py
# Hugging Face file-viewer header captured with the source (not part of the program):
# last commit "Update app.py" by pdufour (ca68348, verified); file size 4.73 kB.
import os
import shutil
import subprocess
import sys
import tarfile
import tempfile
import urllib.error
import urllib.request

import streamlit as st
from huggingface_hub import HfApi
# Hugging Face access token. A value from Streamlit secrets wins; the
# environment variable is consulted only when that lookup yields nothing truthy.
HF_TOKEN = st.secrets.get("HF_TOKEN")
if not HF_TOKEN:
    HF_TOKEN = os.environ.get("HF_TOKEN")

# Hugging Face account name, resolved in the same order, with the
# SPACE_AUTHOR_NAME environment variable as the last fallback.
HF_USERNAME = st.secrets.get("HF_USERNAME")
if not HF_USERNAME:
    HF_USERNAME = os.environ.get("HF_USERNAME")
if not HF_USERNAME:
    HF_USERNAME = os.environ.get("SPACE_AUTHOR_NAME")
TRANSFORMERS_BASE_URL = "https://github.com/xenova/transformers.js/archive/refs"
TRANSFORMERS_REPOSITORY_REVISION = "3.0.0"


def _resolve_ref_type(base_url, revision, timeout=30):
    """Return ``"tags"`` if *revision* is downloadable as a tag, else ``"heads"``.

    ``urlopen`` raises ``HTTPError`` for a 404 instead of returning a response,
    so the missing-tag case has to be handled as an exception; comparing
    ``getcode()`` alone never reaches the ``"heads"`` fallback.

    Args:
        base_url: Archive URL prefix ending in ``/refs``.
        revision: Tag or branch name to look up.
        timeout: Seconds to wait for the probe before giving up.

    Raises:
        urllib.error.HTTPError: For any HTTP failure other than 404.
        urllib.error.URLError: When the host cannot be reached.
    """
    probe_url = f"{base_url}/tags/{revision}.tar.gz"
    try:
        # The context manager closes the connection without reading the body.
        with urllib.request.urlopen(probe_url, timeout=timeout) as response:
            return "tags" if response.getcode() == 200 else "heads"
    except urllib.error.HTTPError as err:
        if err.code != 404:
            # Rate limiting or server errors say nothing about the ref type.
            raise
        return "heads"


# NOTE: Streamlit re-executes the script on every interaction, so this probe
# runs once per rerun.
TRANSFORMERS_REF_TYPE = _resolve_ref_type(
    TRANSFORMERS_BASE_URL, TRANSFORMERS_REPOSITORY_REVISION
)
TRANSFORMERS_REPOSITORY_URL = f"{TRANSFORMERS_BASE_URL}/{TRANSFORMERS_REF_TYPE}/{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
TRANSFORMERS_REPOSITORY_PATH = "./transformers.js"
ARCHIVE_PATH = f"./transformers_{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
HF_BASE_URL = "https://huggingface.co"
# One-time bootstrap: download and unpack the transformers.js sources (the
# conversion subprocess runs with this folder as its working directory).
# Skipped entirely when the folder already exists.
if not os.path.exists(TRANSFORMERS_REPOSITORY_PATH):
    try:
        urllib.request.urlretrieve(TRANSFORMERS_REPOSITORY_URL, ARCHIVE_PATH)

        with tempfile.TemporaryDirectory() as tmp_dir:
            with tarfile.open(ARCHIVE_PATH, "r:gz") as tar:
                if hasattr(tarfile, "data_filter"):
                    # Refuse members that would escape tmp_dir (absolute
                    # paths, "..", unsafe links) where the runtime supports it.
                    tar.extractall(tmp_dir, filter="data")
                else:
                    tar.extractall(tmp_dir)

            # Assumes the archive unpacks to a single top-level folder.
            extracted_folder = os.path.join(tmp_dir, os.listdir(tmp_dir)[0])

            # shutil.move falls back to copy+delete when the temp directory
            # lives on a different filesystem, where os.rename would fail.
            shutil.move(extracted_folder, TRANSFORMERS_REPOSITORY_PATH)
    finally:
        # Never leave a (possibly partial) archive behind, even on failure.
        if os.path.exists(ARCHIVE_PATH):
            os.remove(ARCHIVE_PATH)

    print("Repository downloaded and extracted successfully.")
# Page heading, followed by the free-text field naming the model to convert.
st.write("## Convert a HuggingFace model to ONNX")

_MODEL_ID_PROMPT = (
    "Enter the HuggingFace model ID to convert. Example: `EleutherAI/pythia-14m`"
)
input_model_id = st.text_input(_MODEL_ID_PROMPT)
def _normalize_model_id(raw_model_id):
    """Return the bare repo ID for *raw_model_id*, which may be a full Hub URL."""
    return raw_model_id.replace(f"{HF_BASE_URL}/", "").strip()


def _is_safe_model_id(model_id):
    """Return True when *model_id* has the shape ``name`` or ``namespace/name``.

    The ID ends up inside a filesystem path and on a command line, so empty
    segments, ``.``/``..`` segments, backslashes and a leading ``-`` are
    rejected.
    """
    if model_id.startswith("-") or "\\" in model_id:
        return False
    segments = model_id.split("/")
    if len(segments) > 2:
        return False
    return all(segment not in ("", ".", "..") for segment in segments)


def _convert_and_upload(raw_model_id):
    """Drive the convert-then-upload flow for the model the user entered.

    Renders all progress and errors through Streamlit and returns nothing.
    The local conversion output folder is removed on every exit path once a
    conversion has been started.

    Args:
        raw_model_id: Text from the input field; a repo ID or a Hub URL.
    """
    source_model_id = _normalize_model_id(raw_model_id)
    if not _is_safe_model_id(source_model_id):
        st.error(f"`{source_model_id}` is not a valid HuggingFace model ID.")
        return
    if not HF_USERNAME:
        # Without this guard the target repo would be named "None/...".
        st.error(
            "No target account configured: set HF_USERNAME (or SPACE_AUTHOR_NAME)."
        )
        return

    model_name = (
        source_model_id.replace("/", "-").replace(f"{HF_USERNAME}-", "").strip()
    )
    output_model_id = f"{HF_USERNAME}/{model_name}-ONNX"
    output_model_url = f"{HF_BASE_URL}/{output_model_id}"
    api = HfApi(token=HF_TOKEN)

    if api.repo_exists(output_model_id):
        st.write("This model has already been converted! 🎉")
        st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
        return

    st.write("This model will be converted and uploaded to the following URL:")
    st.code(output_model_url, language="plaintext")
    if not st.button(label="Proceed", type="primary"):
        return

    # Use the normalized ID for both the script argument and the folder path.
    model_folder_path = f"{TRANSFORMERS_REPOSITORY_PATH}/models/{source_model_id}"
    upload_error_message = None
    try:
        with st.spinner("Converting model..."):
            output = subprocess.run(
                [
                    # Same interpreter as the app, not whatever "python" is on PATH.
                    sys.executable or "python",
                    "-m",
                    "scripts.convert",
                    "--quantize",
                    "--model_id",
                    source_model_id,
                ],
                cwd=TRANSFORMERS_REPOSITORY_PATH,
                capture_output=True,
                text=True,
            )

            # Log the script output
            print("### Script Output ###")
            print(output.stdout)

            # Log any errors
            if output.stderr:
                print("### Script Errors ###")
                print(output.stderr)

            if output.returncode != 0:
                st.error(f"Conversion failed (exit code {output.returncode}).")
                st.code(output.stderr)
                return

            try:
                os.rename(
                    f"{model_folder_path}/onnx/model.onnx",
                    f"{model_folder_path}/onnx/decoder_model_merged.onnx",
                )
                os.rename(
                    f"{model_folder_path}/onnx/model_quantized.onnx",
                    f"{model_folder_path}/onnx/decoder_model_merged_quantized.onnx",
                )
            except OSError as e:
                # The script exited cleanly but did not produce the expected files.
                st.error(f"Conversion failed: {e}")
                st.code(output.stderr)
                return

        st.success("Conversion successful!")
        st.code(output.stderr)

        with st.spinner("Uploading model..."):
            try:
                repository = api.create_repo(
                    output_model_id, exist_ok=True, private=False
                )
                api.upload_folder(
                    folder_path=model_folder_path, repo_id=repository.repo_id
                )
            except Exception as e:
                # UI boundary: report any Hub failure to the user instead of
                # crashing the page with a traceback.
                upload_error_message = str(e)
    finally:
        # No shell involved: the path contains user-supplied text.
        shutil.rmtree(model_folder_path, ignore_errors=True)

    if upload_error_message:
        st.error(f"Upload failed: {upload_error_message}")
    else:
        st.success("Upload successful!")
        st.write("You can now go and view the model on HuggingFace!")
        st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")


if input_model_id:
    _convert_and_upload(input_model_id)