fffiloni commited on
Commit
94f8ab2
·
verified ·
1 Parent(s): ea50b7d

Do not load models on gpu at first

Browse files
Files changed (1) hide show
  1. models/utils.py +13 -8
models/utils.py CHANGED
@@ -81,36 +81,41 @@ def get_model(
81
  freeze_params(pipe.transformer.parameters())
82
  pipe.transformer.enable_gradient_checkpointing()
83
  #pipe = pipe.to(device)
 
84
  elif model_name == "hyper-sd":
85
  base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
86
  repo_name = "ByteDance/Hyper-SD"
87
  ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"
88
- # Load model.
 
89
  unet = UNet2DConditionModel.from_config(
90
  base_model_id, subfolder="unet", cache_dir=cache_dir
91
- ).to(device, dtype)
 
 
92
  unet.load_state_dict(
93
  load_file(
94
  hf_hub_download(repo_name, ckpt_name, cache_dir=cache_dir),
95
  device="cuda",
96
  )
97
  )
 
 
98
  pipe = RewardStableDiffusionXL.from_pretrained(
99
  base_model_id,
100
  unet=unet,
101
- torch_dtype=dtype,
102
- variant="fp16",
103
  cache_dir=cache_dir,
104
  is_hyper=True,
105
  memsave=memsave,
106
  )
 
107
  # Use LCM scheduler instead of ddim scheduler to support specific timestep number inputs
108
  pipe.scheduler = LCMScheduler.from_config(
109
  pipe.scheduler.config, cache_dir=cache_dir
110
  )
111
- #pipe = pipe.to(device, dtype)
112
- # upcast vae
113
- pipe.vae = pipe.vae.to(dtype=torch.float32)
114
  elif model_name == "flux":
115
  pipe = RewardFluxPipeline.from_pretrained(
116
  "black-forest-labs/FLUX.1-schnell",
@@ -187,4 +192,4 @@ def get_multi_apply_fn(
187
  generator=generator,
188
  )
189
  else:
190
- raise ValueError(f"Unknown model type: {model_type}")
 
81
  freeze_params(pipe.transformer.parameters())
82
  pipe.transformer.enable_gradient_checkpointing()
83
  #pipe = pipe.to(device)
84
+
85
  elif model_name == "hyper-sd":
86
  base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
87
  repo_name = "ByteDance/Hyper-SD"
88
  ckpt_name = "Hyper-SDXL-1step-Unet.safetensors"
89
+
90
+ # Load model but don't specify device or dtype (defaults to CPU and float32)
91
  unet = UNet2DConditionModel.from_config(
92
  base_model_id, subfolder="unet", cache_dir=cache_dir
93
+ )
94
+
95
+ # Load state dict into unet (stays on CPU by default)
96
  unet.load_state_dict(
97
  load_file(
98
  hf_hub_download(repo_name, ckpt_name, cache_dir=cache_dir),
99
  device="cuda",
100
  )
101
  )
102
+
103
+ # Initialize the pipeline (it will stay on CPU initially, using default dtype)
104
  pipe = RewardStableDiffusionXL.from_pretrained(
105
  base_model_id,
106
  unet=unet,
107
+ torch_dtype=torch.float16,
108
+ variant="fp16", # Still set fp16 for later use on GPU
109
  cache_dir=cache_dir,
110
  is_hyper=True,
111
  memsave=memsave,
112
  )
113
+
114
  # Use LCM scheduler instead of ddim scheduler to support specific timestep number inputs
115
  pipe.scheduler = LCMScheduler.from_config(
116
  pipe.scheduler.config, cache_dir=cache_dir
117
  )
118
+
 
 
119
  elif model_name == "flux":
120
  pipe = RewardFluxPipeline.from_pretrained(
121
  "black-forest-labs/FLUX.1-schnell",
 
192
  generator=generator,
193
  )
194
  else:
195
+ raise ValueError(f"Unknown model type: {model_type}")