import os
from io import BytesIO
import requests
from datetime import datetime
import random
# Interface utilities
import gradio as gr
# Data utilities
import numpy as np
import pandas as pd
# Image utilities
from PIL import Image
import cv2
# CLIP Model
import torch
from transformers import CLIPTokenizer, CLIPModel
# Style Transfer Model
import paddlehub as hub
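# Install and load PaddleHub's arbitrary style transfer module
# (Parameter-Free Style Projection; see the article link further below)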
os.system("hub install stylepro_artistic==1.0.1")
stylepro_artistic = hub.Module(name="stylepro_artistic")
# CLIP Model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
model = model.to(device)
# Load Data
photos = pd.read_csv("unsplash-dataset/photos.tsv000", sep="\t", header=0)
photo_features = np.load("unsplash-dataset/features.npy")
photo_ids = pd.read_csv("unsplash-dataset/photo_ids.csv")
photo_ids = list(photo_ids["photo_id"])
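# --- Hedged sketch (not part of the original Space) ---------------------------
# `features.npy` is assumed to hold precomputed CLIP image embeddings for the
# Unsplash photos, stored in the same row order as `photo_ids.csv`. The helper
# below is only a minimal sketch of how such a matrix could be rebuilt from
# local image files; the function name, argument names, and batch size are
# illustrative assumptions, and the app never calls it.
def build_photo_features(image_paths, batch_size=32):
    from transformers import CLIPProcessor
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    all_features = []
    with torch.no_grad():
        for i in range(0, len(image_paths), batch_size):
            # Load a batch of images and encode them with the CLIP vision tower
            batch = [Image.open(p).convert("RGB") for p in image_paths[i:i + batch_size]]
            inputs = processor(images=batch, return_tensors="pt").to(device)
            features = model.get_image_features(**inputs)
            all_features.append(features.cpu().numpy())
    # One row per photo, matching the layout expected by the similarity search
    return np.concatenate(all_features, axis=0)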
def image_from_text(text_input):
    start = datetime.now()
    ## Inference
    with torch.no_grad():
        inputs = tokenizer([text_input], padding=True, return_tensors="pt")
        # Move the tokenized inputs to the same device as the model
        inputs = inputs.to(device)
        text_features = model.get_text_features(**inputs).cpu().numpy()
    ## Find similarity
    similarities = list((text_features @ photo_features.T).squeeze(0))
    ## Pick one of the 5 most similar photos at random
    idx = sorted(zip(similarities, range(photo_features.shape[0])), key=lambda x: x[0], reverse=True)
    idx = idx[random.randint(0, 4)][1]
    photo_id = photo_ids[idx]
    photo_data = photos[photos["photo_id"] == photo_id].iloc[0]
    print(f"Time spent at CLIP: {datetime.now() - start}")
    start = datetime.now()
    # Download the photo from Unsplash at a reduced width
    response = requests.get(photo_data["photo_image_url"] + "?w=640")
    pil_image = Image.open(BytesIO(response.content)).convert("RGB")
    open_cv_image = np.array(pil_image)
    # Convert RGB to BGR, as expected by OpenCV / PaddleHub
    open_cv_image = open_cv_image[:, :, ::-1].copy()
    print(f"Time spent at Image request: {datetime.now() - start}")
    return open_cv_image

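# Illustrative stand-alone use of the retrieval step (not executed by the app):
#   content = image_from_text("a cute kangaroo")  # BGR ndarray, ready for cv2 / PaddleHub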
def inference(content, style):
    # Retrieve a content image matching the text prompt with CLIP
    content_image = image_from_text(content)
    start = datetime.now()
    # Apply the uploaded style image to the retrieved content image
    result = stylepro_artistic.style_transfer(
        images=[{
            "content": content_image,
            "styles": [cv2.imread(style.name)]
        }])
    print(f"Time spent at Style Transfer: {datetime.now() - start}")
    # PaddleHub returns a BGR ndarray; convert back to an RGB PIL image
    return Image.fromarray(np.uint8(result[0]["data"])[:, :, ::-1]).convert("RGB")

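# The interface below uses the legacy gr.inputs / gr.outputs component API
# from the Gradio version this Space was written against.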
if __name__ == "__main__":
title = "Neural Style Transfer"
description = "Gradio demo for Neural Style Transfer. To use it, simply enter the text for image content and upload style image. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2003.07694'target='_blank'>Parameter-Free Style Projection for Arbitrary Style Transfer</a> | <a href='https://github.com/PaddlePaddle/PaddleHub' target='_blank'>Github Repo</a></br><a href='https://arxiv.org/abs/2103.00020'target='_blank'>Clip paper</a> | <a href='https://huggingface.co/transformers/model_doc/clip.html' target='_blank'>Hugging Face Clip Implementation</a></p>"
examples=[
["a cute kangaroo", "styles/starry.jpeg"],
["man holding beer", "styles/mona1.jpeg"],
]
interface = gr.Interface(inference,
inputs=[
gr.inputs.Textbox(lines=1, placeholder="Describe the content of the image", default="a cute kangaroo", label="Describe the image to which the style will be applied"),
gr.inputs.Image(type="file", label="Style to be applied"),
],
outputs=gr.outputs.Image(type="pil"),
enable_queue=True,
title=title,
description=description,
article=article,
examples=examples)
interface.launch()