"""First Block Cache adapter for CogVideoX in diffusers.

Wraps the transformer blocks of a CogVideoXTransformer3DModel with
para_attn's residual-diff caching and installs a per-call cache context
on the pipeline.
"""

import functools
import unittest.mock  # explicit submodule import; `import unittest` alone does not expose `unittest.mock`

import torch
from diffusers import CogVideoXTransformer3DModel, DiffusionPipeline

from para_attn.first_block_cache import utils

def apply_cache_on_transformer(
    transformer: CogVideoXTransformer3DModel,
    *,
    residual_diff_threshold=0.04,
):
    # Wrap the original transformer blocks so their outputs can be reused
    # when the first block's residual changes by less than the threshold.
    cached_transformer_blocks = torch.nn.ModuleList(
        [
            utils.CachedTransformerBlocks(
                transformer.transformer_blocks,
                transformer=transformer,
                residual_diff_threshold=residual_diff_threshold,
            )
        ]
    )

    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        # Temporarily swap in the cached blocks for the duration of this
        # call, then delegate to the original forward.
        with unittest.mock.patch.object(
            self,
            "transformer_blocks",
            cached_transformer_blocks,
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    # Bind new_forward as an instance method so only this transformer
    # instance is patched, not the whole class.
    transformer.forward = new_forward.__get__(transformer)

    return transformer

def apply_cache_on_pipe(
    pipe: DiffusionPipeline,
    *,
    shallow_patch: bool = False,
    **kwargs,
):
    original_call = pipe.__class__.__call__

    # Patch the pipeline class's __call__ once (guarded by the _is_cached
    # flag) so each invocation runs inside a fresh cache context.
    if not getattr(original_call, "_is_cached", False):

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with utils.cache_context(utils.create_cache_context()):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call

        new_call._is_cached = True

    # Unless a shallow patch was requested, also wrap the transformer blocks.
    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    pipe._is_cached = True

    return pipe
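
# Minimal usage sketch (comment-only, not part of this module's API): how one
# might apply the cache to a CogVideoX pipeline. The model id, dtype, and
# prompt below are illustrative assumptions.
#
#     import torch
#     from diffusers import CogVideoXPipeline
#
#     pipe = CogVideoXPipeline.from_pretrained(
#         "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
#     ).to("cuda")
#     apply_cache_on_pipe(pipe)  # defaults to residual_diff_threshold=0.04
#     video = pipe("a panda playing guitar").frames[0]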