Error when doing Torch Compile
#422
opened by mghaff
My question is: what is the right way to use FluxPipeline with torch.compile?
I am running Flux on H100 GPUs. Here is my attempt:
import torch
from diffusers import FluxPipeline

# model_blob_path and device are defined elsewhere in my setup
flux_model_pipe = FluxPipeline.from_pretrained(
    pretrained_model_name_or_path=model_blob_path,
    torch_dtype=torch.float16,
    local_files_only=True,
)
flux_model_pipe.to(device)
flux_model_pipe = torch.compile(flux_model_pipe)
Using torch.compile actually reduces performance (from 7s to 120s), while I expected it to make inference faster.
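To check whether this is only one-time compilation overhead, here is a rough timing sketch that warms up once before measuring (the prompt and step count are placeholders I picked, not values from my actual run):

import time

# Warmup: the first call after torch.compile pays the compilation cost.
_ = flux_model_pipe("a photo of a cat", num_inference_steps=28)

# Time a steady-state call; if this is still slow, the regression is not just
# one-time compile latency (e.g. graph breaks or repeated recompilations).
torch.cuda.synchronize()
start = time.perf_counter()
_ = flux_model_pipe("a photo of a cat", num_inference_steps=28)
torch.cuda.synchronize()
print(f"steady-state latency: {time.perf_counter() - start:.2f}s")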
Here is the warning I get that I think is causing the slowdown:
.venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py:679: UserWarning: Graph break due to unsupported builtin unicodedata.category. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
torch._dynamo.utils.warn_once(msg)
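For reference, a sketch of how the graph breaks could be inspected with Dynamo's explain helper (this assumes a recent PyTorch where torch._dynamo.explain returns an object with break_reasons; the prompt and step count are placeholders):

import torch._dynamo as dynamo

# Clear any cached compilation state, then let Dynamo report graph breaks
# for one pipeline call instead of silently falling back to eager mode.
# flux_model_pipe here is the pipeline *before* torch.compile is applied.
dynamo.reset()
report = dynamo.explain(flux_model_pipe)("a photo of a cat", num_inference_steps=4)
print(report.graph_break_count)
print(report.break_reasons)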
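Is the right approach instead to compile only the transformer submodule rather than the whole pipeline? Something like the following (the mode and fullgraph settings are guesses on my part, not something I have validated):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    pretrained_model_name_or_path=model_blob_path,
    torch_dtype=torch.float16,
    local_files_only=True,
).to(device)

# Compile only the denoising transformer so Dynamo does not need to trace
# the Python-level tokenizer/scheduler code that runs around it.
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)

# The first call pays the compilation cost; subsequent calls should be faster.
image = pipe("a photo of a cat", num_inference_steps=28).images[0]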