"""Gradio demo for Text2Cryptopunks: generate CryptoPunk images from a text prompt."""

# system
import functools
import html
import os
import subprocess
from pathlib import Path

# plot
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# gradio
import gradio as gr

# text2punks utils
from text2punks.utils import resize, to_pil_image, model_loader, generate_image

# Checkpoint files expected in the working directory.
T2P_PATH = './Text2Punk-final-7.pt'
CLIP_PATH = './clip-final.pt'

# Google Drive links for the two checkpoints, fetched with the `gdown` CLI.
# Which link produces which file is not recorded here, so both are fetched
# whenever either file is missing.
_CHECKPOINT_URLS = (
    "https://drive.google.com/uc?id=1--27E5dk8GzgvpVL0ofr-m631iymBpUH",
    "https://drive.google.com/uc?id=191a5lTsUPQ1hXaeo6kVNbo_W3WYuXsmF",
)


def _download_checkpoints():
    """Fetch the model checkpoints with gdown unless both are already present.

    Raises:
        subprocess.CalledProcessError: if a gdown invocation exits non-zero.
        FileNotFoundError: if the `gdown` executable is not on PATH.
    """
    # The original check used `and`, so a single missing file was never fetched.
    if all(Path(p).exists() for p in (T2P_PATH, CLIP_PATH)):
        return
    for url in _CHECKPOINT_URLS:
        # The argument list avoids the shell. check=True surfaces a failed
        # download at startup instead of as a confusing missing-file error on
        # the first request.
        subprocess.run(["gdown", url], check=True)


_download_checkpoints()

# knobs to tune (sampling parameters forwarded to generate_image)
top_k = 0.8
temperature = 1.25

# Each grid cell is this many times the generated image's height/width.
_GRID_SCALE = 4


# helper functions
def compose_predictions(images):
    """Paste a batch of generated images side by side into one RGB strip.

    Args:
        images: batch whose ``.shape`` unpacks to ``(b, c, h, w)`` and whose
            items ``images[i]`` are accepted by ``to_pil_image``. Presumably a
            torch tensor returned by ``generate_image`` (TODO confirm).

    Returns:
        A black ``PIL.Image`` of size ``(b * w * 4, h * 4)``, with image ``i``
        pasted at ``x = i * w * 4, y = 0``.
    """
    b, _, h, w = images.shape
    cell_w, cell_h = w * _GRID_SCALE, h * _GRID_SCALE
    image_grid = Image.new("RGB", (b * cell_w, cell_h), color=0)
    for i in range(b):
        # NOTE(review): cells are sized 4x (h, w). That assumes to_pil_image
        # returns an upscaled image, because the explicit resize(images[i], 96)
        # call was commented out in the original. Confirm that tiles fill their
        # cells rather than leaving black gaps.
        image_grid.paste(to_pil_image(images[i]), (i * cell_w, 0))
    return image_grid


@functools.lru_cache(maxsize=1)
def _load_models():
    """Load the (text2punk, clip) model pair once and reuse it across requests."""
    return model_loader(T2P_PATH, CLIP_PATH)


def run_inference(prompt, num_images=32, batch_size=32, num_preds=8):
    """Generate images for ``prompt`` and build the Gradio outputs.

    Args:
        prompt: user-supplied text prompt (untrusted input).
        num_images: number of images to sample, forwarded to generate_image.
        batch_size: sampling batch size, forwarded to generate_image.
        num_preds: forwarded to generate_image as ``top_prediction``.

    Returns:
        ``(title_html, grid_image)`` matching the two output components: the
        HTML title (escaped prompt) and the composed PIL image strip.
    """
    text2punk, clip = _load_models()

    images, _ = generate_image(
        prompt_text=prompt,
        top_k=top_k,
        temperature=temperature,
        num_images=num_images,
        batch_size=batch_size,
        top_prediction=num_preds,
        text2punk_model=text2punk,
        clip_model=clip,
    )
    predictions = compose_predictions(images)

    # The title is rendered by an HTML output component, so escape the prompt
    # to keep user text from injecting markup or script into the page.
    output_title = f""" {html.escape(prompt)} """

    return (output_title, predictions)


# NOTE(review): gr.inputs / gr.outputs, layout= and theme='huggingface' are
# Gradio 2.x APIs (removed in later major versions) — pin gradio accordingly.
outputs = [
    gr.outputs.HTML(label=""),  # To be used as title
    gr.outputs.Image(label=''),
]

description = """ Text2Cryptopunks is an AI model that generates Cryptopunks images from text prompt: """

gr.Interface(
    run_inference,
    inputs=[gr.inputs.Textbox(label='type something like this : "An Ape CryptoPunk that has 2 Attributes, a Pigtails and a Medical Mask."')],
    outputs=outputs,
    title='Text2Cryptopunks',
    description=description,
    article="Created by kTonpa | GitHub",
    layout='vertical',
    theme='huggingface',
    examples=[
        ['Cute Alien cryptopunk that has a 2 Attributes, a Pipe, and a Beanie.'],
        ['A low resolution photo of punky-looking Ape that has 2 Attributes, a Beanie, and a Medical Mask.'],
    ],
    allow_flagging=False,
    live=False,
    # server_port=8999
).launch(share=True)