import os

import torch
import torchvision
import clip
import numpy as np
from huggingface_hub import hf_hub_download
from PIL import Image

from utils import load_model_weights
from model import NetG, CLIP_TXT_ENCODER

# Run on the first GPU when available, otherwise fall back to the CPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Hugging Face access token, read from the environment
HF_TOKEN = os.getenv("HF_TOKEN")

# Pretrained EfficientCLIP-GAN checkpoint hosted on the Hugging Face Hub
repo_id = "VinayHajare/EfficientCLIP-GAN"
file_name = "saved_models/state_epoch_1480.pth"

# Load the CLIP ViT-B/32 backbone and wrap its text tower in the project's encoder
clip_text = "ViT-B/32"
clip_model, preprocessor = clip.load(clip_text, device=device)
clip_model = clip_model.eval()
text_encoder = CLIP_TXT_ENCODER(clip_model).to(device)

# Download the checkpoint, build the generator, and load its weights
model_path = hf_hub_download(repo_id=repo_id, filename=file_name, token=HF_TOKEN)
checkpoint = torch.load(model_path, map_location=torch.device(device))
netG = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)
generator = load_model_weights(netG, checkpoint['model']['netG'], multi_gpus=False)


def generate_image_from_text(caption, batch_size=4):
    """Generate `batch_size` images for `caption` and return them as PIL images."""
    # One latent noise vector per requested image
    noise = torch.randn((batch_size, 100)).to(device)
    with torch.no_grad():
        # Encode the caption with CLIP; repeat the sentence embedding across the batch
        tokenized_text = clip.tokenize([caption]).to(device)
        sent_emb, word_emb = text_encoder(tokenized_text)
        sent_emb = sent_emb.repeat(batch_size, 1)
        generated_images = generator(noise, sent_emb, eval=True).float()

    # Map each tensor from [-1, 1] to [0, 255] and convert to PIL
    pil_images = []
    for image_tensor in generated_images.unbind(0):
        image_tensor = image_tensor.clamp(-1, 1)
        image_tensor = (image_tensor + 1.0) / 2.0
        image_numpy = image_tensor.permute(1, 2, 0).cpu().numpy()
        image_numpy = np.clip(image_numpy, 0, 1)
        pil_image = Image.fromarray((image_numpy * 255).astype(np.uint8))
        pil_images.append(pil_image)

    return pil_images


def generate_image_from_text_with_persistent_storage(caption, batch_size=4):
    """Generate `batch_size` images for `caption`, save them as PNGs, and return the file paths."""
    # One latent noise vector per requested image
    noise = torch.randn((batch_size, 100)).to(device)
    with torch.no_grad():
        # Encode the caption with CLIP; repeat the sentence embedding across the batch
        tokenized_text = clip.tokenize([caption]).to(device)
        sent_emb, word_emb = text_encoder(tokenized_text)
        sent_emb = sent_emb.repeat(batch_size, 1)
        generated_images = generator(noise, sent_emb, eval=True).float()

    # Write each image to a persistent directory; save_image rescales from [-1, 1]
    permanent_dir = "generated_images"
    os.makedirs(permanent_dir, exist_ok=True)

    image_paths = []
    for idx, image_tensor in enumerate(generated_images.unbind(0)):
        image_path = os.path.join(permanent_dir, f"image_{idx}.png")
        torchvision.utils.save_image(image_tensor, image_path, value_range=(-1, 1), normalize=True)
        image_paths.append(image_path)

    return image_paths
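

# Minimal usage sketch of the two entry points above, assuming the script is run
# directly and the checkpoint download succeeds. The caption string is a made-up
# example, not taken from the project.
if __name__ == "__main__":
    example_caption = "a small bird with a red head and black wings"

    # In-memory path: returns PIL images, saved here only for inspection
    images = generate_image_from_text(example_caption, batch_size=2)
    for i, img in enumerate(images):
        img.save(f"preview_{i}.png")

    # Persistent path: writes PNGs under generated_images/ and returns their paths
    paths = generate_image_from_text_with_persistent_storage(example_caption, batch_size=2)
    print("Saved images:", paths)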