Spaces:
Runtime error
Runtime error
# Code adapted from the following sources: | |
# https://huggingface.co/huggan/fastgan-few-shot-fauvism-still-life | |
# https://huggingface.co/spaces/huggan/FastGan/ | |
import torch | |
from PIL import Image | |
from models import Generator | |
def load_img_generator(model_name_or_path):
    """Load a pretrained FastGAN generator and switch it to inference mode.

    Args:
        model_name_or_path: Hub repo id or local path forwarded to
            ``Generator.from_pretrained``.

    Returns:
        The loaded ``Generator``, in eval mode.
    """
    # ``from_pretrained`` receives the constructor kwargs and returns the model,
    # so call it on the class instead of first building a throwaway instance.
    # NOTE(review): assumes ``from_pretrained`` is a classmethod (as on
    # huggingface_hub's ModelHubMixin) — confirm against models.Generator.
    generator = Generator.from_pretrained(
        model_name_or_path, in_channels=256, out_channels=3
    )
    generator.eval()
    return generator
def _denormalize(input: torch.Tensor) -> torch.Tensor: | |
return (input * 127.5) + 127.5 | |
def generate_img(device, gan_model):
    """Sample one image from a pretrained FastGAN few-shot generator.

    Args:
        device: torch device (or device string) on which to run generation.
        gan_model: suffix of the Hub repo name; weights are loaded from
            ``"huggan/fastgan-few-shot-" + gan_model``.

    Returns:
        A ``PIL.Image.Image`` built from the single generated sample.
    """
    img_generator = load_img_generator(f"huggan/fastgan-few-shot-{gan_model}")
    # The noise lives on ``device``; the model must be there too, otherwise a
    # non-CPU device raises a device-mismatch error in the forward pass.
    img_generator = img_generator.to(device)

    noise = torch.zeros(1, 256, 1, 1, device=device).normal_(0.0, 1.0)
    with torch.no_grad():
        # The generator returns a tuple; only the first element is used here.
        gan_images, _ = img_generator(noise)

    # Take the only sample in the batch -> (C, H, W), presumably in [-1, 1].
    # Clamp after rescaling so out-of-range values cannot wrap on the uint8 cast.
    pixels = _denormalize(gan_images[0]).clamp(0, 255)
    # Channels-last uint8 on the CPU, as expected by PIL.
    pixels = pixels.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return Image.fromarray(pixels)