Spaces:
Runtime error
Runtime error
import torch | |
class TorchCompileModel:
    """Node that wraps a model's diffusion network with ``torch.compile``.

    The class-level attributes (``RETURN_TYPES``, ``FUNCTION``, ``CATEGORY``,
    ``EXPERIMENTAL``) and ``INPUT_TYPES`` describe the node to whatever
    loader consumes ``NODE_CLASS_MAPPINGS``.
    NOTE(review): this looks like the ComfyUI custom-node convention --
    confirm against the host application.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification for this node.

        Declared as a classmethod so it can be called on the class itself
        (``TorchCompileModel.INPUT_TYPES()``); without the decorator that
        call fails with a missing-positional-argument ``TypeError``.

        Returns:
            A dict with one ``"required"`` section holding two inputs:
            ``"model"`` of type ``"MODEL"`` and ``"backend"``, a choice
            between ``"inductor"`` and ``"cudagraphs"``.
        """
        return {
            "required": {
                "model": ("MODEL",),
                "backend": (["inductor", "cudagraphs"],),
            }
        }

    # One output, of type "MODEL".
    RETURN_TYPES = ("MODEL",)
    # Name of the method on this class that performs the node's work.
    FUNCTION = "patch"

    CATEGORY = "_for_testing"
    EXPERIMENTAL = True

    def patch(self, model, backend):
        """Return a clone of ``model`` whose diffusion model is compiled.

        Args:
            model: Object exposing ``clone()``; the clone must expose
                ``get_model_object(name)`` and
                ``add_object_patch(name, obj)``.
            backend: Backend name forwarded unchanged to ``torch.compile``.

        Returns:
            A one-element tuple containing the patched clone.
        """
        # The patch is registered on the clone, not on the object passed in.
        m = model.clone()
        compiled = torch.compile(
            model=m.get_model_object("diffusion_model"),
            backend=backend,
        )
        m.add_object_patch("diffusion_model", compiled)
        return (m,)
# Registry of node identifier -> implementing class exported by this module.
# NOTE(review): presumably read by the host application's node loader -- confirm.
NODE_CLASS_MAPPINGS = {"TorchCompileModel": TorchCompileModel}