# S3Diff/src/model.py
import torch
import os
import requests
from tqdm import tqdm
from diffusers import DDPMScheduler, EulerDiscreteScheduler
from typing import Any

# Multi-step variant kept for reference; it uses EulerDiscreteScheduler instead of DDPMScheduler:
# def make_1step_sched(pretrained_path, step=4):
#     noise_scheduler_1step = EulerDiscreteScheduler.from_pretrained(pretrained_path, subfolder="scheduler")
#     noise_scheduler_1step.set_timesteps(step, device="cuda")
#     noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
#     return noise_scheduler_1step

def make_1step_sched(pretrained_path):
    """Load the DDPM scheduler from `pretrained_path` and configure it for single-step inference on CUDA."""
    noise_scheduler_1step = DDPMScheduler.from_pretrained(pretrained_path, subfolder="scheduler")
    noise_scheduler_1step.set_timesteps(1, device="cuda")
    noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
    return noise_scheduler_1step
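
# Example usage (a minimal sketch; "stabilityai/sd-turbo" is an assumed checkpoint path,
# any diffusers model directory containing a "scheduler" subfolder works):
#   noise_scheduler = make_1step_sched("stabilityai/sd-turbo")
#   timestep = noise_scheduler.timesteps[0]  # the single timestep used for one-step denoising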

def my_lora_fwd(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    """Replacement forward for peft LoRA layers: routes the LoRA-A output through the
    modulation matrix `self.de_mod` before LoRA-B, mirroring peft's stock forward otherwise."""
    self._check_forward_args(x, *args, **kwargs)
    adapter_names = kwargs.pop("adapter_names", None)

    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        result = self.base_layer(x, *args, **kwargs)
    elif adapter_names is not None:
        result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
    elif self.merged:
        result = self.base_layer(x, *args, **kwargs)
    else:
        result = self.base_layer(x, *args, **kwargs)
        torch_result_dtype = result.dtype
        for active_adapter in self.active_adapters:
            if active_adapter not in self.lora_A.keys():
                continue
            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]
            x = x.to(lora_A.weight.dtype)

            if not self.use_dora[active_adapter]:
                # Standard LoRA path, with the extra de_mod projection between lora_A and lora_B.
                _tmp = lora_A(dropout(x))
                if isinstance(lora_A, torch.nn.Conv2d):
                    _tmp = torch.einsum('...khw,...kr->...rhw', _tmp, self.de_mod)
                elif isinstance(lora_A, torch.nn.Linear):
                    _tmp = torch.einsum('...lk,...kr->...lr', _tmp, self.de_mod)
                else:
                    raise NotImplementedError('Only Conv2d and Linear LoRA layers are supported.')
                result = result + lora_B(_tmp) * scaling
            else:
                # DoRA path: defer to peft's own weight-decomposed update.
                x = dropout(x)
                result = result + self._apply_dora(x, lora_A, lora_B, scaling, active_adapter)

        result = result.to(torch_result_dtype)

    return result
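
# A minimal sketch of how this forward could be bound to peft LoRA layers (the `model` and
# `de_mod` names are assumptions for illustration, not part of this file; S3Diff may set a
# per-layer modulation rather than a shared one):
#   import types
#   from peft.tuners.lora.layer import LoraLayer
#   for module in model.modules():
#       if isinstance(module, LoraLayer):
#           module.de_mod = de_mod  # modulation matrix consumed by the einsums above
#           module.forward = types.MethodType(my_lora_fwd, module)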

def download_url(url, outf):
    """Stream `url` to `outf` with a progress bar; skip the download if the file already exists."""
    if not os.path.exists(outf):
        print(f"Downloading checkpoint to {outf}")
        response = requests.get(url, stream=True)
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        block_size = 1024  # 1 Kibibyte
        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
        with open(outf, 'wb') as file:
            for data in response.iter_content(block_size):
                progress_bar.update(len(data))
                file.write(data)
        progress_bar.close()
        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
            print("ERROR, something went wrong")
        else:
            print(f"Downloaded successfully to {outf}")
    else:
        print(f"Skipping download, {outf} already exists")