# Spaces: Sleeping (page-status banner captured along with the source; not part of the program)
import torch | |
import torch.onnx | |
import onnx | |
from VitsModelSplit.vits_model_only_d import Vits_models_only_decoder | |
from VitsModelSplit.vits_model import VitsModel | |
import gradio as gr | |
import os | |
from transformers import AutoTokenizer | |
# Tokenizer shared by the converter; it supplies the example input ids used
# when tracing the text-driven model for ONNX export.
_TOKENIZER_REPO = "wasmdashai/vits-ar"
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_REPO)
def create_file(file_path):
    """Return ``file_path`` when it names a downloadable ONNX export, else ``None``.

    The value comes straight from a text box in the web UI, so it is treated
    as untrusted: only an existing regular file that resolves inside ``/tmp``
    (where the converter writes its exports) and ends in ``.onnx`` is handed
    back for download.

    Args:
        file_path (str): Path of the temporary ONNX file requested by the user.

    Returns:
        str | None: ``file_path`` unchanged if it is an allowed, existing
        file; ``None`` otherwise.
    """
    # An empty text box (or a missing value) can never name a file.
    if not file_path:
        return None

    export_dir = os.path.realpath("/tmp")
    # Resolve symlinks and ``..`` segments before checking containment so the
    # check cannot be bypassed with a crafted path.
    resolved = os.path.realpath(file_path)

    inside_export_dir = resolved.startswith(export_dir + os.sep)
    if not inside_export_dir or not resolved.endswith(".onnx"):
        return None

    # ``isfile`` (rather than ``exists``) also rejects directories, which
    # cannot be served as a download.
    if not os.path.isfile(resolved):
        return None

    # Return the path so the file can be downloaded.
    return file_path
class OnnxModelConverter:
    """Export VITS checkpoints to ONNX files and expose that through a Gradio UI.

    Each ``convert_*`` method loads a pretrained model, traces it with
    ``torch.onnx.export`` using an example input, and writes the result to
    ``/tmp/<onnx_filename>.onnx``.
    """

    # Directory every exported ONNX file is written to.
    OUTPUT_DIR = "/tmp"

    def __init__(self):
        self.model = None

    @staticmethod
    def _normalize_token(token):
        """Map an empty token (e.g. a blank text box) to ``None``."""
        # Presumably an empty string is not a usable credential; passing None
        # lets the loader fall back to its default behaviour instead.
        return token or None

    def _output_path(self, onnx_filename):
        """Build the output path for a user-supplied ONNX file name.

        Only the final path component of ``onnx_filename`` is used, so names
        containing directory separators or ``..`` cannot place the export
        outside ``OUTPUT_DIR``.

        Raises:
            ValueError: If no usable file name remains.
        """
        name = os.path.basename(str(onnx_filename or "").strip())
        if not name:
            raise ValueError("onnx_filename must be a non-empty file name.")
        return f"{self.OUTPUT_DIR}/{name}.onnx"

    def download_file(self, file_path):
        """Return ``file_path`` if it is an existing regular file, else ``None``.

        Args:
            file_path (str): Path of the file that was just exported.

        Returns:
            str | None: The same path when the file exists, otherwise ``None``.
        """
        # ``isfile`` also rejects directories, which cannot be downloaded.
        if not file_path or not os.path.isfile(file_path):
            return None
        return file_path

    def convert(self, model_name, token, onnx_filename, conversion_type):
        """
        Main function to handle different types of model conversions.

        Args:
            model_name (str): Name of the model to convert.
            token (str): Access token for loading the model.
            onnx_filename (str): Desired filename for the ONNX output.
            conversion_type (str): Type of conversion ('decoder', 'only_decoder', or 'full_model').

        Returns:
            str: The path to the generated ONNX file.

        Raises:
            ValueError: If ``conversion_type`` is not one of the supported values.
        """
        handlers = {
            "decoder": self.convert_decoder,
            "only_decoder": self.convert_only_decoder,
            "full_model": self.convert_full_model,
        }
        handler = handlers.get(conversion_type)
        if handler is None:
            raise ValueError("Invalid conversion type. Choose from 'decoder', 'only_decoder', or 'full_model'.")
        return handler(model_name, token, onnx_filename)

    def convert_decoder(self, model_name, token, onnx_filename):
        """
        Converts only the decoder part of the Vits model to ONNX format.

        Loads ``VitsModel`` and exports its ``decoder`` sub-module, traced
        with a random ``(1, 192, 10)`` float tensor.

        Args:
            model_name (str): Name of the model to convert.
            token (str): Access token for loading the model.
            onnx_filename (str): Desired filename for the ONNX output.

        Returns:
            str: The path to the generated ONNX file.
        """
        # Validate the output name first so a bad name fails before the
        # (potentially slow) model load.
        onnx_file = self._output_path(onnx_filename)
        model = VitsModel.from_pretrained(model_name, token=self._normalize_token(token))
        example_input = torch.randn(1, 192, 10)
        torch.onnx.export(
            model.decoder,
            example_input,
            onnx_file,
            opset_version=11,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={"input": {0: "batch_size", 2: "seq_len"},
                          "output": {0: "batch_size", 1: "sequence_length"}}
        )
        return self.download_file(onnx_file)

    def convert_only_decoder(self, model_name, token, onnx_filename):
        """
        Converts the ``Vits_models_only_decoder`` variant of the model to ONNX format.

        The model is traced with token ids produced by the module-level
        ``tokenizer`` from a short Arabic example sentence.

        Args:
            model_name (str): Name of the model to convert.
            token (str): Access token for loading the model.
            onnx_filename (str): Desired filename for the ONNX output.

        Returns:
            str: The path to the generated ONNX file.
        """
        onnx_file = self._output_path(onnx_filename)
        model = Vits_models_only_decoder.from_pretrained(model_name, token=self._normalize_token(token))
        # Example sentence restored from a mis-decoded (mojibake) literal.
        inputs = tokenizer("السلام عليكم كيف الحال", return_tensors="pt")
        # Trace the model with integer token ids as the example input.
        example_inputs = inputs.input_ids.type(torch.LongTensor)
        # NOTE(review): unlike the other exports, no opset_version is pinned
        # here, so the exporter default is used - confirm this is intended.
        torch.onnx.export(
            model,
            example_inputs,
            onnx_file,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={"input": {0: "batch_size", 1: "sequence_length"},
                          "output": {0: "batch_size", 1: "sequence_length"}}
        )
        return self.download_file(onnx_file)

    def convert_full_model(self, model_name, token, onnx_filename):
        """
        Converts the full Vits model (including encoder and decoder) to ONNX format.

        The model is traced with 100 random token ids drawn from the size of
        the text encoder's embedding table.

        Args:
            model_name (str): Name of the model to convert.
            token (str): Access token for loading the model.
            onnx_filename (str): Desired filename for the ONNX output.

        Returns:
            str: The path to the generated ONNX file.
        """
        onnx_file = self._output_path(onnx_filename)
        model = VitsModel.from_pretrained(model_name, token=self._normalize_token(token))
        vocab_size = model.text_encoder.embed_tokens.weight.size(0)
        example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
        torch.onnx.export(
            model,
            example_input,
            onnx_file,
            opset_version=11,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch_size', 1: 'sequence_length'}, 'output': {0: 'batch_size'}}
        )
        return self.download_file(onnx_file)

    def starrt(self):
        """Build the Gradio interface for the converter.

        Returns:
            gr.Blocks: The assembled (not yet launched) UI.
        """
        # NOTE(review): the nesting of rows/columns below is a reconstruction;
        # the original layout indentation was not recoverable - confirm.
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    text_n_model = gr.Textbox(label="name model")
                    # The access token is a secret, so mask it in the browser.
                    text_n_token = gr.Textbox(label="token", type="password")
                    text_n_onxx = gr.Textbox(label="name model onxx")
                    choice = gr.Dropdown(choices=["decoder", "only_decoder", "full_model"], label="My Dropdown")
                with gr.Column():
                    btn = gr.Button("convert")
                    gr.Label("return name model onxx")
                    converted_file = gr.File(label="Download ONNX File")
                    btn.click(
                        self.convert,
                        [text_n_model, text_n_token, text_n_onxx, choice],
                        [converted_file],
                    )
                    btx = gr.Textbox("namefile")
                    download_button1 = gr.Button("send")
                    download_button = gr.File(label="Download ONNX File")
                    download_button1.click(create_file, [btx], [download_button])
        return demo
# Script entry: build the converter, assemble its Gradio UI, and serve it
# with a public share link.
c = OnnxModelConverter()
cc = c.starrt()
cc.launch(share=True)