from math import ceil

from multilingual_clip import pt_multilingual_clip
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel

st.set_page_config(layout="wide")


def init():
    st.session_state.current_page = 1

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the open CLIP models
    ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
    ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"

    st.session_state.ml_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(
        ml_model_name
    )
    st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)

    st.session_state.ja_model = AutoModel.from_pretrained(
        ja_model_name, trust_remote_code=True
    ).to(device)
    st.session_state.ja_tokenizer = AutoTokenizer.from_pretrained(
        ja_model_name, trust_remote_code=True
    )

    st.session_state.search_image_ids = []

    # Load the image metadata and IDs
    st.session_state.images_info = pd.read_csv("./metadata.csv")
    st.session_state.images_info.set_index("filename", inplace=True)
    st.session_state.image_ids = list(
        open("./images_list.txt", "r", encoding="utf-8").read().strip().split("\n")
    )

    # Load the precomputed image feature vectors
    ml_image_features = np.load("./multilingual_features.npy")
    ja_image_features = np.load("./hakuhodo_features.npy")

    # Convert the features to tensors: Float32 on CPU and Float16 on GPU
    if device == "cpu":
        ml_image_features = torch.from_numpy(ml_image_features).float().to(device)
        ja_image_features = torch.from_numpy(ja_image_features).float().to(device)
    else:
        ml_image_features = torch.from_numpy(ml_image_features).to(device)
        ja_image_features = torch.from_numpy(ja_image_features).to(device)

    # Normalize the image features so cosine similarity reduces to a dot product
    st.session_state.ml_image_features = ml_image_features / ml_image_features.norm(
        dim=-1, keepdim=True
    )
    st.session_state.ja_image_features = ja_image_features / ja_image_features.norm(
        dim=-1, keepdim=True
    )


# Load the models and data once per session
if (
    "ml_image_features" not in st.session_state
    or "ja_image_features" not in st.session_state
):
    with st.spinner("Loading models and data, please wait..."):
        init()


# The `encode_search_query` function takes a text description and encodes it into
# a feature vector using the selected CLIP model.
def encode_search_query(search_query, model_type):
    with torch.no_grad():
        if model_type == "M-CLIP (multiple languages)":
            # Encode and normalize the search query using the multilingual model
            text_encoded = st.session_state.ml_model.forward(
                search_query, st.session_state.ml_tokenizer
            )
            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
        else:  # model_type == "J-CLIP (日本語 only)"
            # Encode and normalize the search query using the Japanese model,
            # keeping the token tensors on the same device as the model
            t_text = st.session_state.ja_tokenizer(
                search_query, padding=True, return_tensors="pt"
            ).to(st.session_state.ja_model.device)
            text_encoded = st.session_state.ja_model.get_text_features(**t_text)
            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    # Return the normalized text feature vector
    return text_encoded
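
# Illustrative sketch, not part of the original app: init() assumes each .npy file
# holds one feature row per line of images_list.txt. The helper below (a
# hypothetical name, never called by the app) shows how that assumption could be
# checked for either feature file.
def _check_feature_files(
    features_path="./multilingual_features.npy", ids_path="./images_list.txt"
):
    feats = np.load(features_path)
    ids = open(ids_path, "r", encoding="utf-8").read().strip().split("\n")
    assert feats.ndim == 2, "expected a (num_images, feature_dim) matrix"
    assert feats.shape[0] == len(ids), "expected one feature row per image ID"
    return feats.shape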

# The `find_best_matches` function compares the text feature vector to the feature
# vectors of all images and returns the IDs of the best matching images.
def find_best_matches(text_features, image_features, image_ids):
    # Compute the similarity between the search query and each image using cosine
    # similarity (a dot product, since both sides are already normalized)
    similarities = (image_features @ text_features.T).squeeze(1)

    # Sort the images by their similarity score
    best_image_idx = (-similarities).argsort()

    # Return the image IDs of the best matches, together with their scores
    return [[image_ids[i], similarities[i].item()] for i in best_image_idx]


def clip_search(search_query):
    if st.session_state.search_field_value != search_query:
        st.session_state.search_field_value = search_query

    model_type = st.session_state.active_model

    if len(search_query) >= 1:
        text_features = encode_search_query(search_query, model_type)

        # Rank all images against the query using the features of the active model
        if model_type == "M-CLIP (multiple languages)":
            matches = find_best_matches(
                text_features,
                st.session_state.ml_image_features,
                st.session_state.image_ids,
            )
        else:  # model_type == "J-CLIP (日本語 only)"
            matches = find_best_matches(
                text_features,
                st.session_state.ja_image_features,
                st.session_state.image_ids,
            )

        result_image_ids = [match[0] for match in matches]
        st.session_state.search_image_ids = result_image_ids


def string_search():
    clip_search(st.session_state.search_field_value)
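
# Minimal sanity check of the ranking logic above (an illustrative addition, not
# part of the original app). It is guarded by a hypothetical environment variable
# so it never runs inside the deployed Streamlit app: a query vector identical to
# the first image's feature vector should rank that image first.
import os

if os.environ.get("CLIP_APP_RUN_SANITY_CHECK"):
    _demo_features = torch.randn(4, 8)
    _demo_features /= _demo_features.norm(dim=-1, keepdim=True)
    _demo_query = _demo_features[:1].clone()  # identical to image "a.jpg"
    _demo_matches = find_best_matches(
        _demo_query, _demo_features, ["a.jpg", "b.jpg", "c.jpg", "d.jpg"]
    )
    assert _demo_matches[0][0] == "a.jpg"
    assert abs(_demo_matches[0][1] - 1.0) < 1e-3  # cosine similarity of ~1.0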

st.title("Explore Japanese visual aesthetics with CLIP models")

search_row = st.columns([45, 10, 13, 7, 25], vertical_alignment="center")
with search_row[0]:
    search_field = st.text_input(
        label="search",
        label_visibility="collapsed",
        placeholder="Type something, or click a suggested search below.",
        on_change=string_search,
        key="search_field_value",
    )
with search_row[1]:
    st.button("Search", on_click=string_search, use_container_width=True)
with search_row[2]:
    st.empty()
with search_row[3]:
    st.markdown("**CLIP Model:**")
with search_row[4]:
    st.radio(
        "CLIP Model",
        options=["M-CLIP (multiple languages)", "J-CLIP (日本語 only)"],
        key="active_model",
        on_change=string_search,
        horizontal=True,
        label_visibility="collapsed",
    )

canned_searches = st.columns([12, 22, 22, 22, 22], vertical_alignment="center")
with canned_searches[0]:
    st.markdown("**Suggested searches:**")
if st.session_state.active_model == "M-CLIP (multiple languages)":
    with canned_searches[1]:
        st.button(
            "negative space",
            on_click=clip_search,
            args=["negative space"],
            use_container_width=True,
        )
    with canned_searches[2]:
        st.button("間", on_click=clip_search, args=["間"], use_container_width=True)
    with canned_searches[3]:
        st.button("음각", on_click=clip_search, args=["음각"], use_container_width=True)
    with canned_searches[4]:
        st.button(
            "αρνητικός χώρος",
            on_click=clip_search,
            args=["αρνητικός χώρος"],
            use_container_width=True,
        )
else:
    with canned_searches[1]:
        st.button(
            "間",
            on_click=clip_search,
            args=["間"],
            use_container_width=True,
        )
    with canned_searches[2]:
        st.button("奥", on_click=clip_search, args=["奥"], use_container_width=True)
    with canned_searches[3]:
        st.button("山", on_click=clip_search, args=["山"], use_container_width=True)
    with canned_searches[4]:
        st.button(
            "花に酔えり 羽織着て刀 さす女",
            on_click=clip_search,
            args=["花に酔えり 羽織着て刀 さす女"],
            use_container_width=True,
        )

controls = st.columns([35, 5, 35, 5, 20], gap="large", vertical_alignment="center")
with controls[0]:
    im_per_pg = st.columns([30, 70], vertical_alignment="center")
    with im_per_pg[0]:
        st.markdown("**Images/page:**")
    with im_per_pg[1]:
        batch_size = st.select_slider(
            "Images/page:", range(10, 50, 10), label_visibility="collapsed"
        )
with controls[1]:
    st.empty()
with controls[2]:
    im_per_row = st.columns([30, 70], vertical_alignment="center")
    with im_per_row[0]:
        st.markdown("**Images/row:**")
    with im_per_row[1]:
        row_size = st.select_slider(
            "Images/row:", range(1, 6), value=5, label_visibility="collapsed"
        )

num_batches = ceil(len(st.session_state.image_ids) / batch_size)

with controls[3]:
    st.empty()
with controls[4]:
    pager = st.columns([40, 60], vertical_alignment="center")
    with pager[0]:
        st.markdown(f"Page **{st.session_state.current_page}** of **{num_batches}** ")
    with pager[1]:
        st.number_input(
            "Page",
            min_value=1,
            max_value=num_batches,
            step=1,
            label_visibility="collapsed",
            key="current_page",
        )

if len(st.session_state.search_image_ids) == 0:
    batch = []
else:
    batch = st.session_state.search_image_ids[
        (st.session_state.current_page - 1) * batch_size :
        st.session_state.current_page * batch_size
    ]
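
# Worked example of the pagination above (illustrative numbers only): with 95
# images and batch_size = 20, num_batches = ceil(95 / 20) = 5, and page 5 slices
# results[80:100], which Python truncates to the final 15 items.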

grid = st.columns(row_size)
col = 0
for image_id in batch:
    with grid[col]:
        # Link text is the host portion of the image's permalink URL
        link_text = st.session_state.images_info.loc[image_id]["permalink"].split("/")[2]
        st.html(
            f"""
            {st.session_state.images_info.loc[image_id]['caption']}
            """
        )
        st.caption(
            f"""
            {link_text}
            """,
            unsafe_allow_html=True,
        )
    col = (col + 1) % row_size