# Imag / load_models.py
# Uploaded by Baraaqasem ("Upload 29 files", commit 14ee1a9, 3.61 kB).
import os
import sys
import traceback
import torch
def _download_model(model_cls, dirname, label, model_path):
    """
    Instantiate one model class so it fetches its weights, then verify them.

    Args:
        model_cls: A model class imported from ``videogen_hub.infermodels``;
            it is called once with no arguments and the instance is discarded.
        dirname: Sub-directory of ``model_path`` that must exist afterwards.
        label: Human-readable model name used in status messages.
        model_path: Root directory the models are downloaded into.

    Raises:
        AssertionError: if ``model_path/dirname`` does not exist afterwards.
    """
    try:
        model_cls()
    except Exception:
        # Instantiation is best-effort: only the on-disk check below decides
        # success. Report the error instead of silently swallowing it, and let
        # KeyboardInterrupt / SystemExit propagate (a bare ``except`` caught them).
        print(f"Error while instantiating {label}; checking for its files anyway:",
              file=sys.stderr)
        traceback.print_exc()
    # The instance above is not kept; release unoccupied cached CUDA memory
    # before the next model is loaded.
    torch.cuda.empty_cache()
    target = os.path.join(model_path, dirname)
    if not os.path.exists(target):
        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O``; the exception type matches the previous assertion.
        raise AssertionError(f"{label} model directory not found: {target}")
    print(f"{label} has already been downloaded!")


def load_all_models():
    """
    Download every supported video-generation model into MODEL_PATH.

    Models handled, in order: ConsistI2V, DynamiCrafter, I2VGenXL, LaVie,
    ModelScope, SEINE, ShowOne, StreamingT2V, T2VTurbo, VideoCrafter2,
    OpenSora and OpenSoraPlan. For each one, the model class is instantiated
    once (errors are printed to stderr, not raised), the CUDA cache is
    emptied, and the model's directory under MODEL_PATH is checked.

    Raises:
        AssertionError: if a model's directory is missing after its download
            attempt; the remaining models are then not attempted.

    Returns: None
    """
    sys.path.insert(0, './src/')
    # These imports run only after ./src/ has been added to sys.path.
    # CogVideo is currently not downloaded.
    from src.videogen_hub.infermodels import ConsistI2V
    from src.videogen_hub.infermodels import DynamiCrafter
    from src.videogen_hub.infermodels import I2VGenXL
    from src.videogen_hub.infermodels import LaVie
    from src.videogen_hub.infermodels import ModelScope
    from src.videogen_hub.infermodels import OpenSora
    from src.videogen_hub.infermodels import OpenSoraPlan
    from src.videogen_hub.infermodels import SEINE
    from src.videogen_hub.infermodels import ShowOne
    from src.videogen_hub.infermodels import StreamingT2V
    from src.videogen_hub.infermodels import T2VTurbo
    from src.videogen_hub.infermodels import VideoCrafter2
    from src.videogen_hub import MODEL_PATH

    # (model class, directory under MODEL_PATH, label used in messages)
    models = (
        (ConsistI2V, 'ConsistI2V', 'ConsistI2V'),
        (DynamiCrafter, 'dynamicrafter_256_v1', 'DynamiCrafter'),
        (I2VGenXL, 'i2vgen-xl', 'I2VGenXL'),
        (LaVie, 'lavie', 'Lavie Model'),
        (ModelScope, 'modelscope', 'ModelScope'),
        (SEINE, 'SEINE', 'SEINE'),
        (ShowOne, 'showlab', 'ShowOne'),
        # NOTE(review): 'streamingtv2' does not match the class name
        # StreamingT2V; confirm it is the directory StreamingT2V creates.
        (StreamingT2V, 'streamingtv2', 'StreamingTV'),
        (T2VTurbo, 'T2V-Turbo-VC2', 'T2VTurbo'),
        (VideoCrafter2, 'videocrafter2', 'VideoCrafter'),
        # Do these last, as they're linux-only...
        (OpenSora, 'STDiT2-XL_2', 'OpenSora'),
        (OpenSoraPlan, 'Open-Sora-Plan-v1.1.0', 'OpenSoraPlan'),
    )
    for model_cls, dirname, label in models:
        _download_model(model_cls, dirname, label, MODEL_PATH)
if __name__ == "__main__":
    # Running this file as a script downloads and verifies every model.
    load_all_models()