|
from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG |
|
|
|
|
|
def load_model(args, in_channels, out_channels, factor_kwargs):
    """Build the Hunyuan video model selected by ``args.model``.

    Args:
        args: Model arguments object. It must expose a ``model`` attribute
            naming an entry of ``HUNYUAN_VIDEO_CONFIG``; the whole object is
            also forwarded to the model constructor as its first positional
            argument.
        in_channels (int): Input channels number.
        out_channels (int): Output channels number.
        factor_kwargs (dict): Extra keyword arguments forwarded verbatim to
            the model constructor.

    Returns:
        The constructed ``HYVideoDiffusionTransformer`` instance.

    Raises:
        NotImplementedError: If ``args.model`` is not a key of
            ``HUNYUAN_VIDEO_CONFIG``.
    """
    model_name = args.model

    # Reject unknown model names up front, and say which name was rejected
    # and which ones are accepted so a bad configuration is easy to diagnose.
    if model_name not in HUNYUAN_VIDEO_CONFIG:
        raise NotImplementedError(
            f"Unsupported model {model_name!r}; "
            f"available models: {list(HUNYUAN_VIDEO_CONFIG)}"
        )

    # The per-model config and factor_kwargs are both unpacked as keyword
    # arguments, so a key present in both (or clashing with in_channels /
    # out_channels) raises TypeError at call time.
    return HYVideoDiffusionTransformer(
        args,
        in_channels=in_channels,
        out_channels=out_channels,
        **HUNYUAN_VIDEO_CONFIG[model_name],
        **factor_kwargs,
    )
|
|