import os
import torch
import gradio as gr
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from model import AestheticPredictorModel

HFREPO = "City96/CityAesthetics"
MODELS = [
    "CityAesthetics-Anime-v1.8",
]
class CityAestheticsPipeline:
    """
    Demo pipeline for [image=>score] prediction.
    Accepts a list of model paths on initialization.
    The resulting object can be called directly with a PIL image as input.
    Returns a dict with the model name as key and the score [0.0;1.0] as value.
    """
    def __init__(self, model_paths):
        self.models = {}
        for path in model_paths:
            name = os.path.splitext(os.path.basename(path))[0]
            self.models[name] = self.load_model(path)
        clip_ver = "openai/clip-vit-large-patch14"
        self.proc = CLIPImageProcessor.from_pretrained(clip_ver)
        self.clip = CLIPVisionModelWithProjection.from_pretrained(clip_ver)
        print("CityAesthetics: Pipeline init ok") # debug
    def load_model(self, path):
        sd = load_file(path)
        # first layer must map from the 768-dim CLIP ViT-L/14 image embedding,
        # so reject checkpoints trained on a different CLIP version
        assert tuple(sd["up.0.weight"].shape) == (1024, 768) # only allow CLIP ver
        model = AestheticPredictorModel()
        model.load_state_dict(sd)
        model.eval()
        return model
    def __call__(self, raw):
        # preprocess the PIL image and compute its CLIP image embedding
        img = self.proc(images=raw, return_tensors="pt")
        with torch.no_grad():
            emb = self.clip(pixel_values=img["pixel_values"])
        emb = emb["image_embeds"].detach().cpu()
        # run each predictor head on the shared embedding
        out = {}
        for name, model in self.models.items():
            pred = model(emb)
            out[name] = float(pred.squeeze(0))
        return out
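
# Illustrative usage (a sketch, not part of the original script): the pipeline
# object is directly callable with a PIL image, e.g.
#
#   from PIL import Image
#   pipe = CityAestheticsPipeline([get_model_path(x) for x in MODELS])
#   scores = pipe(Image.open("test.png"))  # "test.png" is a placeholder path
#   # -> {"CityAesthetics-Anime-v1.8": <score in [0.0;1.0]>}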
def get_model_path(name):
    fname = f"{name}.safetensors"
    # local path: [models/AesPred-Anime-v1.8.safetensors]
    path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
    if os.path.isfile(os.path.join(path, fname)):
        print("CityAesthetics: Using local model")
        return os.path.join(path, fname)
    # huggingface hub fallback
    print("CityAesthetics: Using HF Hub model")
    return str(hf_hub_download(
        token    = os.environ.get("HFS_TOKEN") or True,
        repo_id  = HFREPO,
        filename = fname,
        # subfolder = fname.split('-')[1],
    ))
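
# Illustrative resolution order for the helper above (the cache path is an
# assumption based on hf_hub_download's default behavior):
#   get_model_path("CityAesthetics-Anime-v1.8")
#   # -> "<script dir>/models/CityAesthetics-Anime-v1.8.safetensors" if present
#   # -> otherwise the hf_hub_download cache path for City96/CityAesthetics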
article = """\
# About

This is the live demo for the CityAesthetics class of predictors.

For more information, check out the [Huggingface Hub](https://huggingface.co/city96/CityAesthetics) or [GitHub page](https://github.com/city96/CityAesthetics).

## CityAesthetics-Anime

This flavor is optimized for scoring anime images with at least one subject present.

### Intentional biases:

- Completely negative towards real-life photos (ideal score of 0%)
- Strongly negative towards text (subtitles, memes, etc.) and manga panels
- Fairly negative towards 3D and, to some extent, 2.5D images
- Negative towards western cartoons and stylized images (chibi, parody)

### Expected output scores:

- Non-anime images should always score below 20%
- Sketches/rough lineart/oekaki get around 20-40%
- Flat shading/TV anime gets around 40-50%
- Scores above 50% mostly reflect my personal style preferences

### Issues:

- Tends to filter out male characters.
- Requires at least one subject; won't work for scenery/landscapes.
- Noticeable positive bias towards anime characters with animal ears.
- Hit-or-miss with AI-generated images, since their style and quality are not correlated.
"""
pipeline = CityAestheticsPipeline([get_model_path(x) for x in MODELS])
gr.Interface(
    fn      = pipeline,
    title   = "CityAesthetics demo",
    article = article,
    inputs  = gr.Image(label="Input image", type="pil"),
    outputs = gr.Label(label="Model prediction", show_label=False),
    examples = "./examples",
    allow_flagging    = "never",
    analytics_enabled = False,
).launch()
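
# To run the demo locally (assuming this file is saved as app.py and the
# imports above are installed): `python app.py`, then open the local URL
# that Gradio prints.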