import argparse
from sys import exit
from typing import List

from rknn.api import RKNN

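# Convert one Stable Diffusion pipeline component (text_encoder, unet or
# vae_decoder) from ONNX to an RKNN model for the requested target platform
# and output resolution(s).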
def convert_pipeline_component(onnx_path: str, resolution_list: List[List[int]], target_platform: str = 'rk3588'):
    print(f'Converting {onnx_path} to RKNN model')
    print(f'with target platform {target_platform}')
    print('with resolutions:')
    for res in resolution_list:
        print(f'- {res[0]}x{res[1]}')

    use_dynamic_shape = False
    if len(resolution_list) > 1:
        print("Warning: RKNN dynamic shape support is probably broken, may throw errors")
        use_dynamic_shape = True

    batch_size = 1
    LATENT_RESIZE_FACTOR = 8

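    # Input names and shapes differ per component. Latent-space tensors are 8x
    # smaller than the target image in both spatial dimensions, hence
    # LATENT_RESIZE_FACTOR.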
    if "text_encoder" in onnx_path:
        input_size_list = [[[1, 77]]]
        inputs = ['input_ids']
        use_dynamic_shape = False
    elif "unet" in onnx_path:
        input_size_list = []
        for res in resolution_list:
            input_size_list.append(
                [[1, 4, res[0] // LATENT_RESIZE_FACTOR, res[1] // LATENT_RESIZE_FACTOR],
                 [1],
                 [1, 77, 768],
                 [1, 256]]
            )
        inputs = ['sample', 'timestep', 'encoder_hidden_states', 'timestep_cond']
    elif "vae_decoder" in onnx_path:
        input_size_list = []
        for res in resolution_list:
            input_size_list.append(
                [[1, 4, res[0] // LATENT_RESIZE_FACTOR, res[1] // LATENT_RESIZE_FACTOR]]
            )
        inputs = ['latent_sample']
    else:
        print("Unknown component: ", onnx_path)
        exit(1)

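    # With a single resolution the shapes are handed to load_onnx() as fixed
    # sizes; with several resolutions they are registered here via dynamic_input.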
    rknn = RKNN(verbose=True)

    print('--> Config model')
    rknn.config(target_platform=target_platform, optimization_level=3, single_core_mode=True,
                dynamic_input=input_size_list if use_dynamic_shape else None)
    print('done')

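    # In the fixed-shape case, pass the explicit input names and the shapes of
    # the first (and only) resolution.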
    print('--> Loading model')
    ret = rknn.load_onnx(model=onnx_path,
                         inputs=None if use_dynamic_shape else inputs,
                         input_size_list=None if use_dynamic_shape else input_size_list[0])
    if ret != 0:
        print('Load model failed!')
        exit(ret)
    print('done')

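    # Build in floating point; enabling quantization would additionally require
    # a calibration dataset.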
    print('--> Building model')
    ret = rknn.build(do_quantization=False, rknn_batch_size=batch_size)
    if ret != 0:
        print('Build model failed!')
        exit(ret)
    print('done')

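    # The converted model is written next to the source file, with the .onnx
    # extension swapped for .rknn.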
    print('--> Export RKNN model')
    ret = rknn.export_rknn(onnx_path.replace('.onnx', '.rknn'))
    if ret != 0:
        print('Export RKNN model failed!')
        exit(ret)
    print('done')

    rknn.release()
    print('RKNN model is converted successfully!')

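# Parse a comma-separated "WIDTHxHEIGHT" string such as "256x256,512x512"
# into [[width, height], ...].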
def parse_resolution_list(resolution: str) -> List[List[int]]:
    resolution_pairs = resolution.split(',')
    parsed_resolutions = []
    for pair in resolution_pairs:
        width, height = map(int, pair.split('x'))
        parsed_resolutions.append([width, height])

    return parsed_resolutions

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert Stable Diffusion ONNX models to RKNN models')
    parser.add_argument('-m', '--model-dir', type=str, help='Directory containing the Stable Diffusion ONNX models', required=True)
    parser.add_argument('-c', '--components', type=str, help='Name of the components to convert, e.g. "text_encoder,unet,vae_decoder"', default='text_encoder,unet,vae_decoder')
    parser.add_argument('-r', '--resolutions', type=str, help='Comma-separated list of resolutions for the model, e.g. "256x256,512x512"', default='256x256')
    parser.add_argument('--target_platform', type=str, help='Target platform for the RKNN model, default is "rk3588"', default='rk3588')
    args = parser.parse_args()

    components = args.components.split(',')

    resolution_list = parse_resolution_list(args.resolutions)
    if len(resolution_list) == 0:
        print("Error: No resolutions specified")
        exit(1)

    for component in components:
        onnx_path = f'{args.model_dir}/{component.strip()}/model.onnx'
        convert_pipeline_component(onnx_path, resolution_list, args.target_platform)
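# Example invocation (script name and model directory are placeholders):
#   python convert_to_rknn.py -m ./stable-diffusion-onnx -c text_encoder,unet,vae_decoder -r 256x256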