# ONNXServies / app.py
import torch
import torch.onnx
import onnx
from VitsModelSplit.vits_model_only_d import Vits_models_only_decoder
from VitsModelSplit.vits_model import VitsModel
import gradio as gr
import os
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("wasmdashai/vits-ar")
def create_file(file_path):
    """Return the given file path so Gradio can offer it for download, or None if the file does not exist."""
    if not os.path.exists(file_path):
        return None
    # Return the file path so it can be downloaded
    return file_path
class OnnxModelConverter:
def __init__(self):
self.model = None
    def download_file(self, file_path):
        """Return the path of the exported ONNX file if it exists, otherwise None."""
        if not os.path.exists(file_path):
            return None
        return file_path
def convert(self, model_name, token, onnx_filename, conversion_type):
"""
Main function to handle different types of model conversions.
Args:
model_name (str): Name of the model to convert.
token (str): Access token for loading the model.
onnx_filename (str): Desired filename for the ONNX output.
conversion_type (str): Type of conversion ('decoder', 'only_decoder', or 'full_model').
Returns:
str: The path to the generated ONNX file.
"""
if conversion_type == "decoder":
return self.convert_decoder(model_name, token, onnx_filename)
elif conversion_type == "only_decoder":
return self.convert_only_decoder(model_name, token, onnx_filename)
elif conversion_type == "full_model":
return self.convert_full_model(model_name, token, onnx_filename)
else:
raise ValueError("Invalid conversion type. Choose from 'decoder', 'only_decoder', or 'full_model'.")
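
    # Usage sketch (hypothetical token and file name): for example,
    #   convert("wasmdashai/vits-ar", "hf_xxx", "vits_ar_full", "full_model")
    # would export the full model to /tmp/vits_ar_full.onnx and return that path for download.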
def convert_decoder(self, model_name, token, onnx_filename):
"""
Converts only the decoder part of the Vits model to ONNX format.
Args:
model_name (str): Name of the model to convert.
token (str): Access token for loading the model.
onnx_filename (str): Desired filename for the ONNX output.
Returns:
str: The path to the generated ONNX file.
"""
model = VitsModel.from_pretrained(model_name, token=token)
onnx_file = f"/tmp/{onnx_filename}.onnx"
        example_input = torch.randn(1, 192, 10)  # dummy latent input for the decoder: (batch, channels, frames)
torch.onnx.export(
model.decoder,
example_input,
onnx_file,
opset_version=11,
input_names=['input'],
output_names=['output'],
dynamic_axes={"input": {0: "batch_size", 2: "seq_len"},
"output": {0: "batch_size", 1: "sequence_length"}}
)
return self.download_file(onnx_file)
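
    # Optional helper (sketch): sanity-check the exported decoder with onnxruntime.
    # Assumes the onnxruntime and numpy packages are available; the method name is
    # illustrative and it is not wired into the UI.
    def check_decoder_onnx(self, onnx_file):
        """Run the exported decoder graph on a random latent and return the output shape."""
        import numpy as np
        import onnxruntime as ort

        session = ort.InferenceSession(onnx_file)
        # Same (batch, channels, frames) layout as the example input used for the export above;
        # the frame axis was exported as dynamic, so any length works.
        latent = np.random.randn(1, 192, 25).astype(np.float32)
        output = session.run(None, {"input": latent})[0]
        return output.shape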
def convert_only_decoder(self, model_name, token, onnx_filename):
"""
        Exports the Vits_models_only_decoder variant of the model (tokenized text in, decoder output out) to ONNX format.
Args:
model_name (str): Name of the model to convert.
token (str): Access token for loading the model.
onnx_filename (str): Desired filename for the ONNX output.
Returns:
str: The path to the generated ONNX file.
"""
model = Vits_models_only_decoder.from_pretrained(model_name, token=token)
onnx_file = f"/tmp/{onnx_filename}.onnx"
        # Arabic sample sentence ("Peace be upon you, how are you?") used as the example input for tracing
        inputs = tokenizer("ุงู„ุณู„ุงู… ุนู„ูŠูƒู… ูƒูŠู ุงู„ุญุงู„", return_tensors="pt")
        example_inputs = inputs.input_ids.type(torch.LongTensor)
torch.onnx.export(model,
example_inputs,
onnx_file,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size", 1: "sequence_length"},
"output": {0: "batch_size", 1: "sequence_length"}})
return self.download_file(onnx_file)
def convert_full_model(self, model_name, token, onnx_filename):
"""
Converts the full Vits model (including encoder and decoder) to ONNX format.
Args:
model_name (str): Name of the model to convert.
token (str): Access token for loading the model.
onnx_filename (str): Desired filename for the ONNX output.
Returns:
str: The path to the generated ONNX file.
"""
model = VitsModel.from_pretrained(model_name, token=token)
onnx_file = f"/tmp/{onnx_filename}.onnx"
vocab_size = model.text_encoder.embed_tokens.weight.size(0)
        example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)  # dummy token ids drawn from the model's vocabulary
torch.onnx.export(
model,
example_input,
onnx_file,
opset_version=11,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size', 1: 'sequence_length'}, 'output': {0: 'batch_size'}}
)
return self.download_file(onnx_file)
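
    # Optional helper (sketch): sanity-check the exported full model with onnxruntime by feeding it
    # tokenized text; the only_decoder export takes the same token-id input, so it works there too.
    # Assumes the onnxruntime and numpy packages are available; not wired into the UI.
    def check_full_model_onnx(self, onnx_file, text="ุงู„ุณู„ุงู… ุนู„ูŠูƒู…"):
        """Run the exported ONNX graph on tokenized text and return the output shape."""
        import numpy as np
        import onnxruntime as ort

        session = ort.InferenceSession(onnx_file)
        # The export used int64 token ids, so cast explicitly before running the session.
        input_ids = tokenizer(text, return_tensors="np").input_ids.astype(np.int64)
        output = session.run(None, {"input": input_ids})[0]
        return output.shape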
    def start(self):
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    text_n_model = gr.Textbox(label="Model name")
                    text_n_token = gr.Textbox(label="Access token")
                    text_n_onxx = gr.Textbox(label="ONNX file name")
                    choice = gr.Dropdown(choices=["decoder", "only_decoder", "full_model"], label="Conversion type")

                with gr.Column():
                    btn = gr.Button("Convert")
                    label = gr.Label("Converted ONNX file")
                    onnx_file_output = gr.File(label="Download ONNX File")
                    btn.click(self.convert, [text_n_model, text_n_token, text_n_onxx, choice], [onnx_file_output])

                    file_path_box = gr.Textbox(label="File path")
                    send_btn = gr.Button("Send")
                    download_file_output = gr.File(label="Download ONNX File")
                    send_btn.click(create_file, [file_path_box], [download_file_output])

        return demo
converter = OnnxModelConverter()
demo = converter.start()
demo.launch(share=True)