Error when doing Torch Compile

#422
by mghaff - opened

My question: what is the right way to use FluxPipeline with torch.compile?

Here is my attempt. I am running Flux on H100 GPUs. Below is my code:

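import torch
from diffusers import FluxPipeline

# model_blob_path and device are defined elsewhere (a local model directory and e.g. "cuda")
flux_model_pipe = FluxPipeline.from_pretrained(
    pretrained_model_name_or_path=model_blob_path,
    torch_dtype=torch.float16,
    local_files_only=True,
)
flux_model_pipe.to(device)
# Compile the entire pipeline object (tokenizers, text encoders, transformer, VAE)
flux_model_pipe = torch.compile(flux_model_pipe)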
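For comparison, here is a minimal sketch of the narrower pattern shown in the diffusers performance docs: compile only the denoising network instead of the whole pipeline object, so Python-side tokenizer code (a likely source of the `unicodedata.category` call) stays in eager mode. It assumes the pipeline exposes its denoiser as a `transformer` attribute, as `FluxPipeline` does; the helper name `compile_denoiser` is hypothetical and the compile options are left to the caller.

```python
import torch


def compile_denoiser(pipe, **compile_kwargs):
    """Compile only the pipeline's denoiser, leaving tokenizers, text encoders
    and the VAE in eager mode.

    Assumes `pipe` has a `transformer` attribute holding an nn.Module
    (true for diffusers' FluxPipeline). `compile_kwargs` is forwarded to
    torch.compile, e.g. mode="max-autotune" or fullgraph=True.
    """
    # torch.compile on an nn.Module returns a wrapper module; compilation
    # itself happens lazily on the first forward call.
    pipe.transformer = torch.compile(pipe.transformer, **compile_kwargs)
    return pipe


# Usage with the pipeline from the question (not executed here):
# flux_model_pipe = compile_denoiser(flux_model_pipe, fullgraph=True)
```

Keeping the compiled region to a single `nn.Module` means Dynamo only has to trace tensor code, which is where compilation pays off.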

Adding torch.compile makes the pipeline much slower (from 7 s to 120 s), whereas I expected it to be faster.
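One thing worth ruling out when comparing those two numbers: the first call(s) to a compiled model include Dynamo tracing and backend compilation, which can take far longer than a single eager run, and every graph break or recompilation adds more. The sketch below times a callable only after a few untimed warm-up calls and synchronizes CUDA before reading the clock; `time_after_warmup` and its defaults are hypothetical choices for illustration.

```python
import time

import torch


def time_after_warmup(fn, *args, warmup=2, iters=3, **kwargs):
    """Return the mean wall-clock seconds per call of fn(*args, **kwargs),
    measured only after `warmup` untimed calls (which absorb compilation)."""
    for _ in range(warmup):
        fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for queued GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure all timed GPU work has finished
    return (time.perf_counter() - start) / iters


# Example with the pipeline from the question (not executed here):
# seconds = time_after_warmup(flux_model_pipe, "a photo of a cat")
```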

Here is the warning I get, which I think is the cause:

.venv/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py:679: UserWarning: Graph break due to unsupported builtin unicodedata.category. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
  torch._dynamo.utils.warn_once(msg)