import json

import numpy as np
import streamlit as st
import torch
from huggingface_hub import hf_hub_download

# Name of the configuration file that load_model() fetches from the Hub repo.
CONFIG_NAME = "config.json"

# Download settings shared by load_model(): each one is forwarded unchanged as
# the keyword argument of the same name to hf_hub_download() and to
# LightweightGAN._from_pretrained().
revision = None
cache_dir = None
force_download = False
proxies = None
resume_download = False
local_files_only = False
token = None

# Project-local model class instantiated by load_model(); kept at its original
# position after the module-level download settings.
from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN


def load_model(model_name="ceyda/butterfly_cropped_uniq1K_512"):
    """Load a pre-trained LightweightGAN from the Hugging Face Hub.

    The repository's ``config.json`` is downloaded and parsed first, then
    passed to ``LightweightGAN._from_pretrained`` together with the
    module-level download settings (``revision``, ``cache_dir``,
    ``force_download``, ``proxies``, ``resume_download``,
    ``local_files_only`` and ``token``).

    Args:
        model_name (str): Hub repository id of the checkpoint. The value is
            passed through ``str()`` before use. Defaults to
            "ceyda/butterfly_cropped_uniq1K_512".

    Returns:
        LightweightGAN: The object returned by ``_from_pretrained``, after
        ``eval()`` has been called on it.
    """
    repo_id = str(model_name)

    # Fetch the config file from the Hub and parse it.
    config_file = hf_hub_download(
        repo_id=repo_id,
        filename=CONFIG_NAME,
        revision=revision,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        resume_download=resume_download,
        token=token,
        local_files_only=local_files_only,
    )
    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    # NOTE(review): latent_dim and image_size are hard-coded here rather than
    # read from `config`; presumably they match the default checkpoint --
    # confirm before pointing model_name at a different repository.
    gan = LightweightGAN(latent_dim=256, image_size=512)

    # NOTE(review): both `token` and `use_auth_token` are supplied; presumably
    # the installed _from_pretrained accepts both keywords -- confirm against
    # the huggan / huggingface_hub versions in use.
    gan = gan._from_pretrained(
        model_id=repo_id,
        revision=revision,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        resume_download=resume_download,
        local_files_only=local_files_only,
        token=token,
        use_auth_token=False,
        config=config,
    )

    gan.eval()
    return gan


def generation(gan, batch_size=1):
    """Sample a batch of images from the generator ``gan.G``.

    Draws ``batch_size`` standard-normal latent vectors of length
    ``gan.latent_dim``, runs them through ``gan.G`` with gradient tracking
    disabled, and converts the result to an 8-bit NumPy array.

    Args:
        gan: Object exposing a callable ``G`` and an integer ``latent_dim``.
        batch_size (int): Number of latent vectors to draw. Defaults to 1.

    Returns:
        numpy.ndarray: ``uint8`` array obtained by clamping the generator
        output to [0, 1], scaling by 255 and moving axis 1 of the 4-D output
        to the last position (assumes ``G`` returns
        (batch, channels, height, width) -- TODO confirm).
    """
    with torch.no_grad():
        latents = torch.randn(batch_size, gan.latent_dim)
        # clamp_ modifies the generator output in place before scaling.
        ims = gan.G(latents).clamp_(0.0, 1.0) * 255
        # astype(np.uint8) truncates the fractional part (no rounding).
        ims = ims.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
    return ims