jbilcke-hf's picture
jbilcke-hf HF staff
Upload 30 files
f08eddf verified
raw
history blame contribute delete
771 Bytes
from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG
def load_model(args, in_channels, out_channels, factor_kwargs):
    """Build the Hunyuan video model selected by ``args.model``.

    Args:
        args: Model arguments. ``args.model`` is read by attribute access, so
            this must be a namespace-like object rather than a plain dict.
            The object is also forwarded unchanged as the first positional
            argument of the model constructor.
        in_channels (int): Input channels number.
        out_channels (int): Output channels number.
        factor_kwargs (dict): Extra keyword arguments unpacked into the model
            constructor call.

    Returns:
        The ``HYVideoDiffusionTransformer`` instance built from
        ``HUNYUAN_VIDEO_CONFIG[args.model]``.

    Raises:
        NotImplementedError: If ``args.model`` is not a key of
            ``HUNYUAN_VIDEO_CONFIG``.
    """
    model_name = args.model

    # Guard clause: fail early with a message that names the rejected model
    # and the supported ones, instead of a bare NotImplementedError.
    if model_name not in HUNYUAN_VIDEO_CONFIG:
        raise NotImplementedError(
            f"Unsupported model {model_name!r}; "
            f"supported models: {list(HUNYUAN_VIDEO_CONFIG)}"
        )

    return HYVideoDiffusionTransformer(
        args,
        in_channels=in_channels,
        out_channels=out_channels,
        **HUNYUAN_VIDEO_CONFIG[model_name],
        **factor_kwargs,
    )