|
from base64 import b64encode |
|
from io import BytesIO |
|
from math import ceil |
|
|
|
from multilingual_clip import pt_multilingual_clip |
|
import numpy as np |
|
import pandas as pd |
|
from PIL import Image |
|
import requests |
|
import streamlit as st |
|
import torch |
|
from torchvision.transforms import ToPILImage |
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
from CLIP_Explainability.clip_ import load, tokenize |
|
from CLIP_Explainability.vit_cam import ( |
|
interpret_vit, |
|
vit_perword_relevance, |
|
) |
|
|
|
# Pixel limits used when scaling images for display: the explanation overlay is
# resized to fit one of these, and result thumbnails are capped at the height.
MAX_IMG_WIDTH = 500
MAX_IMG_HEIGHT = 800

# Lay the page out across the full browser width.
st.set_page_config(layout="wide")
|
|
|
|
|
|
|
def find_best_matches(text_features, image_features, image_ids):
    """Rank every image by its dot-product similarity to the query features.

    Args:
        text_features: Encoded query; its transpose is multiplied with
            ``image_features`` and axis 1 of the product is squeezed, so a
            single query row is expected.
        image_features: One feature row per image, aligned with ``image_ids``.
        image_ids: Identifiers in the same order as the rows of
            ``image_features``.

    Returns:
        A list of ``[image_id, score]`` pairs ordered from the highest score
        to the lowest.
    """
    scores = (image_features @ text_features.T).squeeze(1)

    # An ascending argsort of the negated scores gives a descending ranking.
    ranking = (-scores).argsort().tolist()
    score_values = scores.tolist()

    return [[image_ids[idx], score_values[idx]] for idx in ranking]
|
|
|
|
|
|
|
def encode_search_query(search_query, model_type):
    """Encode a text query into a unit-length feature tensor.

    Args:
        search_query: The query text.
        model_type: Label of the model to use. "M-CLIP (multiple languages)"
            selects the multilingual text encoder; any other value selects
            the Japanese CLIP text encoder.

    Returns:
        The text features, L2-normalised along the last dimension and placed
        on ``st.session_state.device``.
    """
    device = st.session_state.device

    with torch.no_grad():
        if model_type == "M-CLIP (multiple languages)":
            text_encoded = st.session_state.ml_model.forward(
                search_query, st.session_state.ml_tokenizer
            )
        else:
            # The Japanese model is moved to `device` when it is loaded, so the
            # tokenised inputs must be moved there too or the forward pass
            # fails with a device mismatch on CUDA.
            t_text = st.session_state.ja_tokenizer(
                search_query, padding=True, return_tensors="pt"
            ).to(device)
            text_encoded = st.session_state.ja_model.get_text_features(**t_text)

        # The pre-computed image features live on `device`; keep the query on
        # the same device so the similarity matmul does not fail.
        text_encoded = text_encoded.to(device)
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    return text_encoded
|
|
|
|
|
def clip_search(search_query):
    """Run a search and store the ranked results in the session state.

    The search field value is synchronised with ``search_query`` first. An
    empty query leaves the previously stored results untouched.

    Args:
        search_query: The text to search for.
    """
    state = st.session_state

    # Keep the text box in sync when the search came from a suggestion button.
    if state.search_field_value != search_query:
        state.search_field_value = search_query

    model_type = state.active_model

    if len(search_query) < 1:
        return

    text_features = encode_search_query(search_query, model_type)

    # Each model has its own pre-computed image feature matrix.
    if model_type == "M-CLIP (multiple languages)":
        image_features = state.ml_image_features
    else:
        image_features = state.ja_image_features

    matches = find_best_matches(text_features, image_features, state.image_ids)

    state.search_image_ids = [image_id for image_id, _score in matches]
    state.search_image_scores = {image_id: score for image_id, score in matches}
|
|
|
|
|
def string_search():
    """Search for whatever text is currently stored for the search field.

    Does nothing when the search field has no entry in the session state yet.
    """
    if "search_field_value" not in st.session_state:
        return

    clip_search(st.session_state.search_field_value)
|
|
|
|
|
def load_image_features():
    """Load the image features for the active vision mode and refresh the search.

    Reads the ``.npy`` feature files for both models, moves them to the
    session's device, stores the row-normalised tensors in the session state
    and then re-runs the current search.
    """
    # Map the vision mode onto the file-name prefix of its feature files;
    # any mode other than "tiled"/"stretched" falls back to the cropped set.
    mode = st.session_state.vision_mode
    if mode == "tiled":
        prefix = "tiled"
    elif mode == "stretched":
        prefix = "resized"
    else:
        prefix = "cropped"

    ml_array = np.load(f"./image_features/{prefix}_ml_features.npy")
    ja_array = np.load(f"./image_features/{prefix}_ja_features.npy")

    device = st.session_state.device

    for state_key, array in (
        ("ml_image_features", ml_array),
        ("ja_image_features", ja_array),
    ):
        features = torch.from_numpy(array)
        if device == "cpu":
            # Cast to float32 on CPU only (presumably the stored features may
            # be half precision -- TODO confirm).
            features = features.float()
        features = features.to(device)

        # Store unit-length rows.
        st.session_state[state_key] = features / features.norm(dim=-1, keepdim=True)

    string_search()
|
|
|
|
|
def init():
    """Populate the session state with models, data and default UI values.

    Loads the image models, text encoders and tokenizers for both CLIP
    variants, reads the image metadata and the list of image ids, sets the
    default model, vision mode and empty search results, and finally loads
    the pre-computed image features.
    """
    state = st.session_state
    state.current_page = 1

    device = "cuda" if torch.cuda.is_available() else "cpu"
    state.device = device

    ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
    ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"

    with st.spinner("Loading models and data, please wait..."):
        # Multilingual model: image side, then text side and tokenizer.
        ml_image_model, ml_image_preprocess = load(
            ml_model_path, device=device, jit=False
        )
        state.ml_image_model = ml_image_model
        state.ml_image_preprocess = ml_image_preprocess

        state.ml_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(
            ml_model_name
        )
        state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)

        ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
        ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"

        # Japanese model: image side, then text side (moved to `device`) and tokenizer.
        ja_image_model, ja_image_preprocess = load(
            ja_model_path, device=device, jit=False
        )
        state.ja_image_model = ja_image_model
        state.ja_image_preprocess = ja_image_preprocess

        ja_model = AutoModel.from_pretrained(ja_model_name, trust_remote_code=True)
        state.ja_model = ja_model.to(device)
        state.ja_tokenizer = AutoTokenizer.from_pretrained(
            ja_model_name, trust_remote_code=True
        )

        # Image metadata, looked up by file name.
        state.images_info = pd.read_csv("./metadata.csv")
        state.images_info.set_index("filename", inplace=True)

        # One image id per line.
        with open("./images_list.txt", "r", encoding="utf-8") as images_list:
            raw_ids = images_list.read()
        state.image_ids = raw_ids.strip().split("\n")

    # Default UI selections and empty result holders.
    state.active_model = "M-CLIP (multiple languages)"

    state.vision_mode = "tiled"
    state.search_image_ids = []
    state.search_image_scores = {}
    state.activations_image = None
    state.text_table_df = None

    with st.spinner("Loading models and data, please wait..."):
        load_image_features()
|
|
|
|
|
# Initialise only once per session: `images_info` is set by init(), so its
# presence marks a session that has already been set up.
session_initialised = "images_info" in st.session_state
if not session_initialised:
    init()
|
|
|
|
|
def _download_rgb_image(image_url):
    """Download the image at ``image_url`` and return it as an RGB PIL image."""
    response = requests.get(image_url, timeout=30)
    # Surface HTTP failures directly instead of as an image-decoding error.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content), formats=["JPEG", "GIF"])
    return image.convert("RGB")


def _split_into_tiles(image, img_dim):
    """Resize ``image`` to whole square tiles along its longer side and cut it up.

    Returns:
        ``(tiles, tile_behavior)``. ``tile_behavior`` is "width" or "height"
        when the image was cut into several ``img_dim`` x ``img_dim`` tiles
        along that axis, or None when it was resized to a single square tile.
    """
    width, height = image.size
    tile_behavior = None
    tile_count = 1

    # Only tile when the rounded aspect ratio is at least 2.
    if width > height:
        ratio = round(width / height)
        if ratio > 1:
            tile_behavior, tile_count = "width", ratio
    elif width < height:
        ratio = round(height / width)
        if ratio > 1:
            tile_behavior, tile_count = "height", ratio

    if tile_behavior == "width":
        resized = image.resize((tile_count * img_dim, img_dim), Image.LANCZOS)
        tiles = [
            resized.crop((i * img_dim, 0, (i + 1) * img_dim, img_dim))
            for i in range(tile_count)
        ]
    elif tile_behavior == "height":
        resized = image.resize((img_dim, tile_count * img_dim), Image.LANCZOS)
        tiles = [
            resized.crop((0, i * img_dim, img_dim, (i + 1) * img_dim))
            for i in range(tile_count)
        ]
    else:
        tiles = [image.resize((img_dim, img_dim), Image.LANCZOS)]

    return tiles, tile_behavior


def _resize_for_center_crop(image, img_dim):
    """Scale the shorter side of ``image`` to ``img_dim`` and find the centre crop.

    Returns:
        ``(resized_image, (left, top))`` where ``(left, top)`` is the top-left
        corner of the centred ``img_dim`` x ``img_dim`` square.
    """
    width, height = image.size
    if width > height:
        resized_dims = (round(width / height * img_dim), img_dim)
    elif width < height:
        resized_dims = (img_dim, round(height / width * img_dim))
    else:
        resized_dims = (img_dim, img_dim)

    resized = image.resize(resized_dims)
    left = round((resized_dims[0] - img_dim) / 2)
    top = round((resized_dims[1] - img_dim) / 2)
    return resized, (left, top)


def visualize_gradcam(viz_image_id):
    """Render the explanation view for one image and the current query.

    Downloads the image, prepares it according to the active vision mode
    (tiled, stretched, or centre-cropped), runs ``interpret_vit`` on each
    prepared tile with the active model, and displays the stitched result.
    For queries of 2-14 tokens it also offers a button that computes a
    per-token importance table with ``vit_perword_relevance``.

    Args:
        viz_image_id: Id of the image to explain; used as the index into
            ``st.session_state.images_info`` and ``search_image_scores``.
    """
    if "search_field_value" not in st.session_state:
        return

    state = st.session_state
    query = state.search_field_value
    device = state.device
    use_ml_model = state.active_model == "M-CLIP (multiple languages)"

    header_cols = st.columns([80, 20], vertical_alignment="bottom")
    with header_cols[0]:
        st.title("Image + query details")
    with header_cols[1]:
        if st.button("Close"):
            st.rerun()

    st.markdown(
        f"**Query text:** {query} | **Image relevance:** {round(state.search_image_scores[viz_image_id], 3)}"
    )

    info_text = st.text("Calculating activation regions...")

    image = _download_rgb_image(state.images_info.loc[viz_image_id]["image_url"])

    # Input resolution passed to the image model: 240 for M-CLIP, 224 otherwise.
    img_dim = 240 if use_ml_model else 224

    orig_img_dims = image.size

    tile_behavior = None

    if state.vision_mode == "tiled":
        image_tiles, tile_behavior = _split_into_tiles(image, img_dim)
    elif state.vision_mode == "stretched":
        image_tiles = [image.resize((img_dim, img_dim), Image.LANCZOS)]
    else:
        # Cropped mode: keep the resized full image so the heat map can be
        # pasted back over its centre square afterwards.
        resized_img, (left, top) = _resize_for_center_crop(image, img_dim)
        image_tiles = [
            resized_img.crop((left, top, left + img_dim, top + img_dim))
        ]

    # Pick the tokenizer, text features and image model for the active model.
    if use_ml_model:
        tokenizer = state.ml_tokenizer
        text_features = state.ml_model.forward(query, tokenizer)
        image_model = state.ml_image_model
        preprocess = state.ml_image_preprocess
    else:
        tokenizer = state.ja_tokenizer
        # The Japanese text model is moved to `device` when loaded, so its
        # inputs must be on the same device.
        t_text = tokenizer(query, return_tensors="pt").to(device)
        text_features = state.ja_model.get_text_features(**t_text)
        image_model = state.ja_image_model
        preprocess = state.ja_image_preprocess

    tokenized_text = tokenizer.tokenize(query)

    image_model.eval()

    image_visualizations = []
    for altered_image in image_tiles:
        image_model.zero_grad()

        p_image = preprocess(altered_image).unsqueeze(0).to(device)

        vis_t = interpret_vit(
            p_image.type(image_model.dtype),
            text_features,
            image_model.visual,
            device,
            img_dim=img_dim,
        )

        image_visualizations.append(vis_t)

    transform = ToPILImage()
    vis_images = [transform(vis_t) for vis_t in image_visualizations]

    if state.vision_mode == "cropped":
        resized_img.paste(vis_images[0], (left, top))
        vis_images = [resized_img]

    # Display size: landscape images are fitted to MAX_IMG_WIDTH, all others
    # to MAX_IMG_HEIGHT, keeping the original aspect ratio.
    if orig_img_dims[0] > orig_img_dims[1]:
        scale_factor = MAX_IMG_WIDTH / orig_img_dims[0]
        scaled_dims = [MAX_IMG_WIDTH, int(orig_img_dims[1] * scale_factor)]
    else:
        scale_factor = MAX_IMG_HEIGHT / orig_img_dims[1]
        scaled_dims = [int(orig_img_dims[0] * scale_factor), MAX_IMG_HEIGHT]

    # Stitch the tiles back together in the direction they were cut.
    if tile_behavior == "width":
        vis_image = Image.new("RGB", (len(vis_images) * img_dim, img_dim))
        for x, v_img in enumerate(vis_images):
            vis_image.paste(v_img, (x * img_dim, 0))
        state.activations_image = vis_image.resize(scaled_dims)
    elif tile_behavior == "height":
        vis_image = Image.new("RGB", (img_dim, len(vis_images) * img_dim))
        for y, v_img in enumerate(vis_images):
            vis_image.paste(v_img, (0, y * img_dim))
        state.activations_image = vis_image.resize(scaled_dims)
    else:
        state.activations_image = vis_images[0].resize(scaled_dims)

    # Embed the result as a PNG data URL.
    image_io = BytesIO()
    state.activations_image.save(image_io, "PNG")
    dataurl = "data:image/png;base64," + b64encode(image_io.getvalue()).decode("ascii")

    st.html(
        f"""<div style="display: flex; flex-direction: column; align-items: center;">
    <img src="{dataurl}" />
</div>"""
    )

    info_text.empty()

    # Drop tokens that consist solely of the "▁" marker.
    tokenized_text = [tok for tok in tokenized_text if tok != "▁"]

    if 1 < len(tokenized_text) < 15 and st.button(
        "Calculate text importance (may take some time)",
    ):
        search_tokens = []
        token_scores = []

        progress_text = f"Processing {len(tokenized_text)} text tokens"
        progress_bar = st.progress(0.0, text=progress_text)

        for t, tok in enumerate(tokenized_text):
            token = tok.replace("▁", "")
            # NOTE: `p_image` is the last tile processed above, so for tiled
            # images only that tile contributes to the token scores.
            word_rel = vit_perword_relevance(
                p_image,
                query,
                image_model,
                tokenize,
                device,
                token,
                data_only=True,
                img_dim=img_dim,
            )
            avg_score = np.mean(word_rel)
            # Skip tokens without a usable score (also avoids dividing by zero).
            if avg_score == 0 or np.isnan(avg_score):
                continue
            search_tokens.append(token)
            token_scores.append(1 / avg_score)

            progress_bar.progress(
                (t + 1) / len(tokenized_text),
                text=f"Processing token {t+1} of {len(tokenized_text)}",
            )
        progress_bar.empty()

        # Turn the inverse mean relevances into percentages that sum to 100.
        normed_scores = torch.softmax(torch.tensor(token_scores), dim=0)

        token_scores = [f"{round(score.item() * 100, 3)}%" for score in normed_scores]
        state.text_table_df = pd.DataFrame(
            {"token": search_tokens, "importance": token_scores}
        )

        st.markdown("**Importance of each text token to relevance score**")
        st.table(state.text_table_df)
|
|
|
|
|
def format_vision_mode(mode_stub):
    """Return the display label for a vision-mode option.

    Args:
        mode_stub: The raw option value, e.g. ``"tiled"``.

    Returns:
        ``"Vision mode: "`` followed by the capitalised option value.
    """
    label = mode_stub.capitalize()
    return "Vision mode: " + label
|
|
|
|
|
@st.dialog(" ", width="large")
def image_modal(vis_image_id):
    """Show the explanation view for ``vis_image_id`` inside a large dialog."""
    visualize_gradcam(vis_image_id)
|
|
|
|
|
st.title("Explore Japanese visual aesthetics with CLIP models")

# CSS overrides for Streamlit's generated markup: tighter captions, spacing
# and buttons, centred full-screen frames, and a hidden built-in dialog close
# button (the dialog renders its own "Close" button).
PAGE_STYLE_OVERRIDES = """
<style>
    [data-testid=stImageCaption] {
        padding: 0 0 0 0;
    }
    [data-testid=stVerticalBlockBorderWrapper] {
        line-height: 1.2;
    }
    [data-testid=stVerticalBlock] {
        gap: .75rem;
    }
    [data-testid=baseButton-secondary] {
        min-height: 1rem;
        padding: 0 0.75rem;
        margin: 0 0 1rem 0;
    }
    div[aria-label="dialog"]>button[aria-label="Close"] {
        display: none;
    }
    [data-testid=stFullScreenFrame] {
        display: flex;
        flex-direction: column;
        align-items: center;
    }
</style>
"""

st.markdown(PAGE_STYLE_OVERRIDES, unsafe_allow_html=True)
|
|
|
# Top row: query box, search button, vision-mode selector and model switch,
# with two narrow spacer columns in between.
search_row = st.columns([45, 5, 1, 15, 1, 8, 25], vertical_alignment="center")

# Typing a query and pressing Enter triggers the search via `on_change`.
search_field = search_row[0].text_input(
    label="search",
    label_visibility="collapsed",
    placeholder="Type something, or click a suggested search below.",
    on_change=string_search,
    key="search_field_value",
)

search_row[1].button(
    "Search", on_click=string_search, use_container_width=True, type="primary"
)

search_row[2].empty()

# Changing the vision mode reloads the matching image features.
search_row[3].selectbox(
    "Vision mode:",
    options=["tiled", "stretched", "cropped"],
    key="vision_mode",
    help="How to consider images that aren't square",
    on_change=load_image_features,
    format_func=format_vision_mode,
    label_visibility="collapsed",
)

search_row[4].empty()

search_row[5].markdown("**CLIP Model:**")

# Switching the model re-runs the current search with the other encoder.
search_row[6].radio(
    "CLIP Model",
    options=["M-CLIP (multiple languages)", "J-CLIP (日本語)"],
    key="active_model",
    on_change=string_search,
    horizontal=True,
    label_visibility="collapsed",
)
|
|
|
# Row of one-click example queries; the examples depend on the active model.
canned_searches = st.columns([12, 22, 22, 22, 22], vertical_alignment="top")
canned_searches[0].markdown("**Suggested searches:**")

if st.session_state.active_model == "M-CLIP (multiple languages)":
    suggested_queries = ["negative space", "間", "음각", "αρνητικός χώρος"]
else:
    suggested_queries = ["間", "奥", "山", "花に酔えり 羽織着て刀 さす女"]

# One button per suggestion, filling the columns after the label column.
for suggestion_col, suggestion in zip(canned_searches[1:], suggested_queries):
    with suggestion_col:
        st.button(
            suggestion,
            on_click=clip_search,
            args=[suggestion],
            use_container_width=True,
        )
|
|
|
# Paging controls: images per page, images per row and the page selector,
# separated by empty spacer columns.
controls = st.columns([35, 5, 35, 5, 20], gap="large", vertical_alignment="center")

im_per_pg = controls[0].columns([30, 70], vertical_alignment="center")
im_per_pg[0].markdown("**Images/page:**")
batch_size = im_per_pg[1].select_slider(
    "Images/page:", range(10, 50, 10), label_visibility="collapsed"
)

controls[1].empty()

im_per_row = controls[2].columns([30, 70], vertical_alignment="center")
im_per_row[0].markdown("**Images/row:**")
row_size = im_per_row[1].select_slider(
    "Images/row:", range(1, 6), value=5, label_visibility="collapsed"
)

# Number of pages needed to show every image at the chosen page size.
num_batches = ceil(len(st.session_state.image_ids) / batch_size)

controls[3].empty()

pager = controls[4].columns([40, 60], vertical_alignment="center")
pager[0].markdown(
    f"Page **{st.session_state.current_page}** of **{num_batches}** "
)
pager[1].number_input(
    "Page",
    min_value=1,
    max_value=num_batches,
    step=1,
    label_visibility="collapsed",
    key="current_page",
)
|
|
|
|
|
# Ids of the images on the current page. Slicing an empty result list simply
# yields an empty page, so no special case is needed.
page_start = (st.session_state.current_page - 1) * batch_size
batch = st.session_state.search_image_ids[page_start : page_start + batch_size]

# Lay the page out row by row: image i goes into column i modulo `row_size`.
grid = st.columns(row_size)
for position, image_id in enumerate(batch):
    # Look the metadata row up once instead of once per field.
    image_info = st.session_state.images_info.loc[image_id]
    permalink = image_info["permalink"]
    # Third "/"-separated component of the permalink (the host name for an
    # absolute URL such as "https://host/...").
    link_text = permalink.split("/")[2]
    score = round(st.session_state.search_image_scores[image_id], 3)

    with grid[position % row_size]:
        st.html(
            f"""<div style="display: flex; flex-direction: column; align-items: center">
    <img src="{image_info['image_url']}" style="max-width: 100%; max-height: {MAX_IMG_HEIGHT}px" />
    <div>{image_info['caption']} <b>[{score}]</b></div>
</div>"""
        )
        # The wrapper is now closed with </div>; it previously ended with a
        # second opening <div>, leaving the markup unbalanced.
        st.caption(
            f"""<div style="display: flex; flex-direction: column; align-items: center; position: relative; top: -12px">
    <a href="{permalink}">{link_text}</a>
</div>""",
            unsafe_allow_html=True,
        )
        st.button(
            "Explain this",
            on_click=image_modal,
            args=[image_id],
            use_container_width=True,
            key=image_id,
        )
|
|