'''Image Completion Demo (ImageGPT)
- Paper: https://arxiv.org/abs/2109.10282
- Code: https://huggingface.co/spaces/nielsr/imagegpt-completion
---
- 2021-12-10 first created
    - examples changed
'''
# NOTE(review): arXiv:2109.10282 does not look like the ImageGPT paper
# ("Generative Pretraining from Pixels") -- please confirm the link above.

from PIL import Image
import matplotlib.pyplot as plt
import os
import numpy as np
from glob import glob
import gradio as gr
from loguru import logger
import torch
from transformers import ImageGPTFeatureExtractor, ImageGPTForCausalImageModeling

# ========== Settings ==========
EXAMPLE_DIR = 'examples'
examples = sorted(glob(os.path.join(EXAMPLE_DIR, '*.jpg')))

# ========== Logger ==========
logger.add('app.log', mode='a')
logger.info('===== APP RESTARTED =====')

# ========== Models ==========
# MODEL_DIR = 'models'
# os.environ['TORCH_HOME'] = MODEL_DIR
# os.environ['TF_HOME'] = MODEL_DIR
feature_extractor = ImageGPTFeatureExtractor.from_pretrained(
    "openai/imagegpt-medium",
    # cache_dir=MODEL_DIR
)
model = ImageGPTForCausalImageModeling.from_pretrained(
    "openai/imagegpt-medium",
    # cache_dir=MODEL_DIR
)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(DEVICE)
logger.info(f'model loaded (DEVICE:{DEVICE})')


def _tokens_to_image(tokens, clusters, height, width):
    """Convert colour-cluster token ids into a uint8 image array.

    Each token indexes a row of ``clusters``; the looked-up values are
    rescaled with ``127.5 * (value + 1.0)``, rounded, and reshaped to
    ``(height, width, 3)``.
    """
    pixels = np.rint(127.5 * (clusters[tokens] + 1.0))
    return np.reshape(pixels, [height, width, 3]).astype(np.uint8)


def process_image(image):
    """Complete the lower part of ``image`` with ImageGPT and return a tile grid.

    The top ``n_px_crop`` pixel rows of the encoded image are used as the
    conditioning prefix, ``batch_size`` completions are sampled, and the
    result is a 2-row grid whose first tile is the (zero-padded) prefix
    followed by the sampled completions.

    Args:
        image: Input image as delivered by the Gradio ``Image`` input
            (configured with ``type="pil"``).

    Returns:
        A ``PIL.Image.Image`` containing the grid of tiles.
    """
    logger.info('--- image file received')

    # Encode the same image batch_size times so each copy gets its own sample.
    batch_size = 7
    encoding = feature_extractor([image] * batch_size, return_tensors="pt")

    # create primers
    samples = encoding.pixel_values.numpy()
    # NOTE(review): assumes `size` is an int side length -- confirm for the
    # installed transformers version.
    n_px = feature_extractor.size
    clusters = feature_extractor.clusters
    n_px_crop = 16
    # crop top n_px_crop rows. These will be the conditioning tokens
    primers = samples.reshape(-1, n_px * n_px)[:, :n_px_crop * n_px]

    # get conditioned image (from first primer tensor), zero-padded at the
    # bottom so it has the same height as the generated samples
    primers_img = _tokens_to_image(primers[0], clusters, n_px_crop, n_px)
    primers_img = np.pad(
        primers_img,
        pad_width=((0, n_px - n_px_crop), (0, 0), (0, 0)),
        mode="constant",
    )

    # generate (no beam search); every sequence starts with the token id
    # vocab_size - 1, followed by its primer tokens
    start_tokens = np.full((batch_size, 1), model.config.vocab_size - 1)
    context = np.concatenate((start_tokens, primers), axis=1)
    context = torch.tensor(context).to(DEVICE)
    output = model.generate(
        input_ids=context,
        max_length=n_px * n_px + 1,
        temperature=1.0,
        do_sample=True,
        top_k=40,
    )

    # decode back to images (convert color cluster tokens back to pixels),
    # dropping the leading start token of each sequence
    samples = output[:, 1:].cpu().detach().numpy()
    samples_img = [_tokens_to_image(s, clusters, n_px, n_px) for s in samples]
    samples_img = [primers_img] + samples_img

    # stack images into two rows of equal length (4 + 4 for batch_size == 7)
    half = len(samples_img) // 2
    row1 = np.hstack(samples_img[:half])
    row2 = np.hstack(samples_img[half:])
    result = np.vstack([row1, row2])

    # return as PIL Image
    completion = Image.fromarray(result)
    logger.info('--- image generated')
    return completion


# NOTE(review): the `article` text below looks like it lost its original
# markup (probably an HTML link) -- restore it if the original is available.
iface = gr.Interface(
    process_image,
    title="이미지의 절반을 지우고 절반을 채워 넣어주는 Image Completion 데모입니다 (ImageGPT)",
    description='주어진 이미지의 절반 아래를 AI가 채워 넣어줍니다 (CPU로 약 100초 정도 소요됩니다)',
    inputs=gr.inputs.Image(type="pil", label='인풋 이미지'),
    outputs=gr.outputs.Image(type="pil", label='AI가 그린 결과'),
    examples=examples,
    enable_queue=True,
    article='\nBased on 🤗 Link\n',
)

if __name__ == '__main__':
    iface.launch(debug=True)