import random
from collections import OrderedDict
from typing import Dict, Optional, Tuple

import torch
from transformers import CLIPTextConfig
from diffusers import UNet2DConditionModel
from optimum.exporters.onnx.model_configs import (
    VisionOnnxConfig,
    NormalizedConfig,
    DummyVisionInputGenerator,
    DummySeq2SeqDecoderTextInputGenerator,
)
from optimum.exporters.openvino import main_export
from optimum.utils.input_generators import DummyInputGenerator, DEFAULT_DUMMY_SHAPES

# IMPORTANT: the downloaded model cache folder needs to contain a scheduler config, otherwise the export fails with errors
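
# One possible way to make sure a scheduler config is present locally (untested sketch;
# assumes the repo ships a scheduler/ subfolder that the export looks for in the cache):
# from huggingface_hub import snapshot_download
# snapshot_download("SimianLuo/LCM_Dreamshaper_v7", allow_patterns=["scheduler/*"])
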
class CustomDummyTimestepInputGenerator(DummyInputGenerator):
    """
    Generates dummy time step inputs.
    """

    SUPPORTED_INPUT_NAMES = (
        "timestep",
        "timestep_cond",
        "text_embeds",
        "time_ids",
    )

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        time_cond_proj_dim: int = 256,
        random_batch_size_range: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        self.task = task
        self.vocab_size = normalized_config.vocab_size
        self.text_encoder_projection_dim = normalized_config.text_encoder_projection_dim
        self.time_ids = 5 if normalized_config.requires_aesthetics_score else 6
        if random_batch_size_range:
            low, high = random_batch_size_range
            self.batch_size = random.randint(low, high)
        else:
            self.batch_size = batch_size
        self.time_cond_proj_dim = normalized_config.get("time_cond_proj_dim", time_cond_proj_dim)

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        shape = [self.batch_size]
        if input_name == "timestep":
            return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=int_dtype)
        if input_name == "timestep_cond":
            shape.append(self.time_cond_proj_dim)
            return self.random_float_tensor(shape, min_value=-1.0, max_value=1.0, framework=framework, dtype=float_dtype)
        shape.append(self.text_encoder_projection_dim if input_name == "text_embeds" else self.time_ids)
        return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype)
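
# For reference, the shapes the generator above produces (read off the code, not measured):
#   "timestep"      -> int tensor   of shape [batch_size]
#   "timestep_cond" -> float tensor of shape [batch_size, time_cond_proj_dim]
#   "text_embeds"   -> float tensor of shape [batch_size, text_encoder_projection_dim]
#   "time_ids"      -> float tensor of shape [batch_size, 6] (5 only if requires_aesthetics_score)
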
class LCMUNetOnnxConfig(VisionOnnxConfig):
    ATOL_FOR_VALIDATION = 1e-3
    # The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
    # operator support, available since opset 14
    DEFAULT_ONNX_OPSET = 14

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        image_size="sample_size",
        num_channels="in_channels",
        hidden_size="cross_attention_dim",
        vocab_size="norm_num_groups",
        allow_new=True,
    )

    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyVisionInputGenerator,
        CustomDummyTimestepInputGenerator,
        DummySeq2SeqDecoderTextInputGenerator,
    )

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        common_inputs = OrderedDict({
            "sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
            "timestep": {0: "steps"},
            "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
            "timestep_cond": {0: "batch_size"},
        })
        # TODO: add text_image, image and image_embeds
        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            common_inputs["text_embeds"] = {0: "batch_size"}
            common_inputs["time_ids"] = {0: "batch_size"}
        return common_inputs

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "out_sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
        }

    @property
    def torch_to_onnx_output_map(self) -> Dict[str, str]:
        return {
            "sample": "out_sample",
        }

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
        dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0]
        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            dummy_inputs["added_cond_kwargs"] = {
                "text_embeds": dummy_inputs.pop("text_embeds"),
                "time_ids": dummy_inputs.pop("time_ids"),
            }
        return dummy_inputs

    def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]:
        # The inherited reordering breaks the input order when timestep_cond is involved,
        # so just return the original input spec unchanged.
        return self.inputs
model_id = "SimianLuo/LCM_Dreamshaper_v7" | |
text_encoder_config = CLIPTextConfig.from_pretrained(model_id, subfolder = "text_encoder") | |
unet_config = UNet2DConditionModel.from_pretrained(model_id, subfolder = "unet").config | |
unet_config.text_encoder_projection_dim = text_encoder_config.projection_dim | |
unet_config.requires_aesthetics_score = False | |
custom_onnx_configs = { | |
"unet": LCMUNetOnnxConfig(config = unet_config, task = "semantic-segmentation") | |
} | |
main_export(model_name_or_path = model_id, output = "./", task = "stable-diffusion", fp16 = False, int8 = False, custom_onnx_configs = custom_onnx_configs) | |
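
# Untested sketch of loading the exported model afterwards (assumption: optimum-intel's
# OVStableDiffusionPipeline accepts the extra timestep_cond input of the LCM UNet;
# an LCMScheduler and guidance handling may still be needed on the pipeline side):
# from optimum.intel import OVStableDiffusionPipeline
# pipe = OVStableDiffusionPipeline.from_pretrained("./", compile=False)
# image = pipe("a photo of a cat", num_inference_steps=4, guidance_scale=1.0).images[0]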