# movie-diffusion / inference.py
# Last commit d56e77f by CarlosMN: "Added genre selection + inpainting via notebook"
import torch
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import numpy as np
from PIL import Image
import requests
import io
from unet import Unet, ConditionalUnet
from diffusion import GaussianDiffusion, DiffusionImageAPI
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def inference1():
    """Fetch a placeholder photo from picsum.photos and return it as a PIL image.

    Returns:
        The image decoded from the response body of
        ``https://picsum.photos/120/80``.

    Raises:
        requests.RequestException: if the request times out, fails, or the
            server answers with an HTTP error status.
    """
    # A timeout keeps a stalled connection from blocking the caller forever.
    response = requests.get("https://picsum.photos/120/80", timeout=10)
    # Fail here on 4xx/5xx instead of handing an error page to the image decoder.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
def inference(
    cond,
    x0=None,
    gif=False,
    model_path="./model_final.pt",
    gif_path="./gif_output/versions.gif",
):
    """Sample one image from the conditional diffusion model.

    Builds a ``ConditionalUnet`` (13 classes) around a 3-channel ``Unet``,
    loads the weights from ``model_path``, wraps it in a ``GaussianDiffusion``
    (1000 noise steps, 192x128 image size) and draws a single sample.

    Args:
        cond: Conditioning value forwarded unchanged to ``diffusion.sample``.
        x0: Optional starting image. When given it is passed through
            ``diffusion.normalize_image``, permuted with ``(2, 0, 1)`` and
            given a leading batch axis before sampling.
            NOTE(review): the permute implies a 3-D, presumably
            channels-last input -- confirm against the caller.
        gif: When true, the intermediate ``versions`` returned by the sampler
            are also written to ``gif_path`` as an animated GIF.
        model_path: Checkpoint file to load (defaults to the original
            hard-coded path).
        gif_path: Output path of the animated GIF (defaults to the original
            hard-coded path).

    Returns:
        The sampled image converted by ``DiffusionImageAPI.tensor_to_image``.
    """
    import os

    model = ConditionalUnet(
        unet=Unet(
            image_channels=3,
            dropout=0.1,
        ),
        num_classes=13,
    )
    model.load_state_dict(torch.load(model_path, map_location=device))
    # Modules start in training mode; switch to eval so the 0.1 dropout is
    # disabled while sampling.
    model.eval()

    diffusion = GaussianDiffusion(
        model=model,
        noise_steps=1000,
        beta_0=1e-4,
        beta_T=0.02,
        image_size=(192, 128),
    )

    if x0 is not None:
        x0 = diffusion.normalize_image(x0)
        x0 = x0.permute(2, 0, 1)
        x0 = x0.unsqueeze(0)

    model.to(device)
    diffusion.to(device)
    if x0 is not None:
        # Keep the starting image on the same device as the model; this is a
        # no-op if it is already there.
        x0 = x0.to(device)

    imageAPI = DiffusionImageAPI(diffusion)
    new_images, versions = diffusion.sample(1, cond=cond, x0=x0)

    if gif:
        frames = [imageAPI.tensor_to_image(step.squeeze(0)) for step in versions]
        # Nothing to animate if the sampler returned no intermediate versions.
        if frames:
            out_dir = os.path.dirname(gif_path)
            if out_dir:
                # Saving fails if the target directory does not exist yet.
                os.makedirs(out_dir, exist_ok=True)
            # Build the GIF from the Pillow frames: 100 ms per frame, looping forever.
            frames[0].save(
                gif_path,
                save_all=True,
                append_images=frames[1:],
                duration=100,
                loop=0,
            )

    return imageAPI.tensor_to_image(new_images.squeeze(0))
if __name__ == "__main__":
    import sys

    # inference() has a required `cond` argument; the previous bare call
    # `inference()` always raised TypeError.
    # NOTE(review): assumes `cond` is an integer class index (the model is
    # built with num_classes=13) -- confirm against GaussianDiffusion.sample.
    cond = int(sys.argv[1]) if len(sys.argv) > 1 else 0
    inference(cond).show()