import os | |
import streamlit as st | |
import torch | |
import pandas as pd | |
import numpy as np | |
import requests | |
from bokeh.plotting import figure, show | |
from bokeh.models import HoverTool, ColumnDataSource, CustomJSHover | |
from bokeh.embed import file_html | |
from bokeh.resources import CDN # Import CDN here | |
from datasets import load_dataset, Dataset, load_from_disk | |
from huggingface_hub import login | |
from sklearn.manifold import TSNE | |
from tqdm import tqdm | |
def load_hf_dataset():
    """Load the GEMRec roster and prompt-book tables from the Hugging Face Hub.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: ``(roster, promptBook)``.
        ``roster`` keeps only the columns model_id, model_name,
        modelVersion_id, modelVersion_name and model_download_count,
        de-duplicated on those columns. ``promptBook`` is the metadata table
        left-joined with ``roster`` on ``(model_id, modelVersion_id)``, with a
        ``weighted_score_sum`` column (set to 0 when the metadata lacks one)
        and a ``row_idx`` column holding each row's index after the merge.
    """
    # Only authenticate when a token is actually supplied: login(token=None)
    # falls back to an interactive prompt, which would block a headless app.
    token = os.environ.get("HF_TOKEN")
    if token:
        login(token=token)

    roster_cols = ['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
                   'model_download_count']
    roster = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Roster', split='train'))
    promptBook = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Metadata', split='train'))

    roster = roster[roster_cols].drop_duplicates().reset_index(drop=True)

    # Add a 'weighted_score_sum' column (default 0) if the metadata has none.
    if 'weighted_score_sum' not in promptBook.columns:
        promptBook.loc[:, 'weighted_score_sum'] = 0

    # Attach model/version names and download counts to each prompt-book row.
    promptBook = promptBook.merge(roster, on=['model_id', 'modelVersion_id'], how='left')

    # Record each row's position; presumably so callers can map back to it
    # after filtering or re-sorting.
    promptBook.loc[:, 'row_idx'] = promptBook.index
    return roster, promptBook
def show_with_bokeh(data, streamlit=False):
    """Render points as a Bokeh scatter plot whose hover tooltip shows an image.

    Args:
        data: Iterable of ``(x, y, image_url)`` triples. Each triple becomes
            one marker, and hovering it displays the image at ``image_url``.
        streamlit: If True, embed the plot in the current Streamlit page via
            ``st.bokeh_chart``; otherwise open it with ``bokeh.plotting.show``.
    """
    points = list(data)
    # zip(*[]) produces nothing to unpack, so empty input gets empty columns.
    if points:
        x_coords, y_coords, image_urls = zip(*points)
    else:
        x_coords, y_coords, image_urls = (), (), ()

    source = ColumnDataSource(data=dict(x=x_coords, y=y_coords, image=image_urls))

    p = figure(width=800, height=600)
    p.scatter(x='x', y='y', size=20, source=source)

    # Bokeh substitutes "@image" with the hovered row's URL directly, so no
    # custom formatter is needed.
    hover = HoverTool()
    hover.tooltips = """
    <div>
        <img src="@image" style="object-fit: contain; height: 100%">
    </div>
    """
    p.add_tools(hover)

    if streamlit:
        st.bokeh_chart(p, use_container_width=True)
    else:
        show(p)
def show_with_bokeh_2(data, image_size=(40, 40), streamlit=False, image_scale=0.1):
    """Draw each point's image directly on the canvas at its (x, y) position.

    Args:
        data: Iterable of ``(x, y, image_url)`` triples.
        image_size: ``(width, height)`` of the images. Each entry is passed
            through ``int``, so numeric strings such as those from
            ``"512x512".split('x')`` are accepted.
        streamlit: If True, embed the plot in the current Streamlit page via
            ``st.bokeh_chart``; otherwise open it with ``bokeh.plotting.show``.
        image_scale: Multiplier from ``image_size`` to plot data units. Each
            image is drawn ``int(size) * image_scale`` units wide/tall.
    """
    points = list(data)
    # zip(*[]) produces nothing to unpack, so empty input gets empty columns.
    if points:
        x_coords, y_coords, image_urls = zip(*points)
    else:
        x_coords, y_coords, image_urls = (), (), ()

    source = ColumnDataSource(data=dict(x=x_coords, y=y_coords, image=image_urls))

    p = figure(width=800, height=600, aspect_ratio=1.0)

    glyph_w = int(image_size[0]) * image_scale
    glyph_h = int(image_size[1]) * image_scale
    p.image_url(url='image', x='x', y='y', source=source, w=glyph_w, h=glyph_h, anchor="center")

    # NOTE(review): confirm HoverTool hit-testing works for ImageURL glyphs in
    # the installed Bokeh version; if it does not, this tooltip never shows.
    hover = HoverTool()
    hover.tooltips = """
    <div>
        <img src="@image" style="object-fit: contain; height: 100%">
    </div>
    """
    p.add_tools(hover)

    if streamlit:
        st.bokeh_chart(p, use_container_width=True)
    else:
        show(p)
if __name__ == '__main__':
    # Demo: embed one prompt's saved feature vectors in 2-D with t-SNE and draw
    # the image each model version produced for that prompt at its position.
    feats_dir = '../data/feats'  # resolved against the current working directory
    prompt_id = '49'  # key into `feats`, i.e. the stem of "<prompt_id>.pt"
    coord_scale = 50  # stretches t-SNE coordinates so the images spread apart

    roster, promptBook = load_hf_dataset()

    print('==> loading feats')
    # Only files named "<digits>.pt" are loaded, keyed by the text before the first '.'.
    feats = {}
    for fname in os.listdir(feats_dir):
        parts = fname.split('.')
        if parts[-1] == 'pt' and parts[0].isdigit():
            # NOTE(review): torch.load unpickles the file; only use trusted data here.
            feats[parts[0]] = torch.load(os.path.join(feats_dir, fname))

    print('==> applying t-SNE')
    prompt_feats = feats[prompt_id]
    tsne = TSNE(n_components=2, random_state=0)
    prompt_feats['tsne'] = tsne.fit_transform(prompt_feats['all'].numpy())
    print(prompt_feats['tsne'])

    # Every key other than 'all'/'tsne' is turned into an int modelVersion_id.
    # NOTE(review): row i of 'all' (and so of 'tsne') is assumed to belong to
    # the i-th such key in dict order — confirm against the feature writer.
    model_version_ids = [
        int(k.item()) for k in prompt_feats.keys() if k != 'all' and k != 'tsne'
    ]
    print(model_version_ids)

    tsne_coords = prompt_feats['tsne']
    prompt_rows = promptBook[promptBook['prompt_id'] == int(prompt_id)]
    data = []
    image_id = None
    for idx, model_version_id in enumerate(model_version_ids):
        # First image this model version generated for the prompt.
        matches = prompt_rows[prompt_rows['modelVersion_id'] == model_version_id]
        image_id = matches['image_id'].iloc[0]
        data.append((
            tsne_coords[idx][0] * coord_scale,
            tsne_coords[idx][1] * coord_scale,
            f"{image_id}.png",
        ))

    if not data:
        raise SystemExit(f'No model-version features found for prompt {prompt_id}.')

    # Size of the last image processed: its 'size' string split on 'x'.
    image_size = promptBook.loc[promptBook['image_id'] == image_id, 'size'].iloc[0].split('x')
    show_with_bokeh_2(data, image_size=image_size, streamlit=True)