"""Plot a 2-D t-SNE embedding of per-model image features for one prompt.

Model metadata comes from the GEMRec datasets on the Hugging Face Hub and the
image features from ``.pt`` files on disk. Each point is one model version
and is drawn with Bokeh, either as a dot with an image tooltip
(``show_with_bokeh``) or as the generated image itself (``show_with_bokeh_2``).
"""
import os

import numpy as np
import pandas as pd
import requests
import streamlit as st
import torch
from bokeh.embed import file_html
from bokeh.models import ColumnDataSource, CustomJSHover, HoverTool
from bokeh.plotting import figure, show
from bokeh.resources import CDN
from datasets import Dataset, load_dataset, load_from_disk
from huggingface_hub import login
from sklearn.manifold import TSNE
from tqdm import tqdm

FEATS_DIR = '../data/feats'
IMAGE_BASE_URL = 'https://modelcofferbucket.s3-accelerate.amazonaws.com'

# Roster columns that are kept and merged into the prompt book.
ROSTER_COLUMNS = [
    'model_id',
    'model_name',
    'modelVersion_id',
    'modelVersion_name',
    'model_download_count',
]

# Factor applied to the t-SNE coordinates before plotting.
TSNE_PLOT_SCALE = 50

# HTML tooltip template; Bokeh substitutes ``@image`` with the value of the
# ``image`` column for the hovered point.
IMAGE_TOOLTIP = """
<div>
    <img src="@image" alt="@image" style="max-width: 200px; max-height: 200px;"/>
</div>
"""


@st.cache_data
def load_hf_dataset():
    """Download the GEMRec roster and metadata and return them as DataFrames.

    Returns:
        A ``(roster, promptBook)`` tuple. ``roster`` holds the de-duplicated
        ``ROSTER_COLUMNS``. ``promptBook`` is the metadata table left-merged
        with the roster on ``model_id`` / ``modelVersion_id``, with a
        ``weighted_score_sum`` column (added as 0 when absent) and a
        ``row_idx`` column recording each row's index.
    """
    # Only log in when a token is configured: ``login(token=None)`` falls back
    # to an interactive prompt, which cannot be answered from a Streamlit run.
    token = os.environ.get("HF_TOKEN")
    if token:
        login(token=token)

    roster = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Roster', split='train'))
    promptBook = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Metadata', split='train'))

    roster = roster[ROSTER_COLUMNS].drop_duplicates().reset_index(drop=True)

    # Add the 'weighted_score_sum' column to promptBook if it does not exist.
    if 'weighted_score_sum' not in promptBook.columns:
        promptBook.loc[:, 'weighted_score_sum'] = 0

    promptBook = promptBook.merge(
        roster[ROSTER_COLUMNS],
        on=['model_id', 'modelVersion_id'],
        how='left',
    )

    # Record the current row index so rows can be traced after filtering.
    promptBook.loc[:, 'row_idx'] = promptBook.index

    return roster, promptBook


def _build_source(data):
    """Turn an iterable of ``(x, y, image_url)`` tuples into a ColumnDataSource."""
    points = list(data)
    if points:
        x_coords, y_coords, image_urls = zip(*points)
    else:
        # ``zip(*[])`` yields nothing to unpack; plot an empty figure instead.
        x_coords, y_coords, image_urls = (), (), ()
    return ColumnDataSource(data=dict(x=x_coords, y=y_coords, image=image_urls))


def _render(plot, streamlit):
    """Show ``plot`` in the Streamlit page, or via Bokeh's ``show`` otherwise."""
    if streamlit:
        st.bokeh_chart(plot, use_container_width=True)
    else:
        show(plot)


def show_with_bokeh(data, streamlit=False):
    """Draw a scatter plot whose hover tooltip shows each point's image.

    Args:
        data: Iterable of ``(x, y, image_url)`` tuples, e.g.
            ``[(2, 5, "https://example.com/a.png"), ...]``.
        streamlit: Render with ``st.bokeh_chart`` when true, otherwise open
            the plot with ``bokeh.plotting.show``.
    """
    source = _build_source(data)

    p = figure(width=800, height=600)
    p.scatter(x='x', y='y', size=20, source=source)
    p.add_tools(HoverTool(tooltips=IMAGE_TOOLTIP))

    _render(p, streamlit)


def show_with_bokeh_2(data, image_size=(40, 40), streamlit=False, scale=0.1):
    """Draw each point as its image, centred on the point's coordinates.

    Args:
        data: Iterable of ``(x, y, image_url)`` tuples.
        image_size: ``(width, height)`` of the images; each entry may be a
            number or a numeric string and is converted with ``int``.
        streamlit: Render with ``st.bokeh_chart`` when true, otherwise open
            the plot with ``bokeh.plotting.show``.
        scale: Factor applied to ``image_size`` to get the glyph size in
            data units.
    """
    source = _build_source(data)

    p = figure(width=800, height=600, aspect_ratio=1.0)

    glyph_w = int(image_size[0]) * scale
    glyph_h = int(image_size[1]) * scale
    p.image_url(url='image', x='x', y='y', source=source,
                w=glyph_w, h=glyph_h, anchor="center")
    p.add_tools(HoverTool(tooltips=IMAGE_TOOLTIP))

    _render(p, streamlit)


def load_feats(feat_dir=FEATS_DIR, prompt_ids=None):
    """Load ``<prompt_id>.pt`` feature files from ``feat_dir``.

    Args:
        feat_dir: Directory containing the feature files.
        prompt_ids: Optional collection of prompt-id strings to load; every
            matching file is loaded when ``None``.

    Returns:
        Dict mapping the prompt id (the file name up to its first dot) to the
        object stored in the file.
    """
    feats = {}
    for file_name in os.listdir(feat_dir):
        parts = file_name.split('.')
        if parts[-1] != 'pt' or not parts[0].isdigit():
            continue
        if prompt_ids is not None and parts[0] not in prompt_ids:
            continue
        # NOTE: torch.load unpickles the file; only point this at trusted data.
        feats[parts[0]] = torch.load(os.path.join(feat_dir, file_name))
    return feats


def main(prompt_id='49'):
    """Embed one prompt's per-model features with t-SNE and plot the images."""
    roster, promptBook = load_hf_dataset()

    print('==> loading feats')
    feats = load_feats(FEATS_DIR, prompt_ids={prompt_id})
    prompt_feats = feats[prompt_id]

    print('==> applying t-SNE')
    tsne = TSNE(n_components=2, random_state=0)
    coords = tsne.fit_transform(prompt_feats['all'].numpy())
    prompt_feats['tsne'] = coords
    print(coords)

    # Every key other than 'all' / 'tsne' is a model-version id exposing
    # ``.item()``. Row i of the embedding is paired with the i-th such key --
    # NOTE(review): this assumes 'all' was stacked in the same key order.
    keys = [
        int(k.item())
        for k in prompt_feats
        if k != 'all' and k != 'tsne'
    ]
    print(keys)
    if not keys:
        raise RuntimeError(f"no per-model features found for prompt {prompt_id}")

    # Filter by prompt once instead of scanning the whole table per model.
    prompt_rows = promptBook[promptBook['prompt_id'] == int(prompt_id)]

    data = []
    image_size = None
    for idx, model_version_id in enumerate(keys):
        rows = prompt_rows[prompt_rows['modelVersion_id'] == model_version_id]
        if rows.empty:
            raise KeyError(
                f"no image for modelVersion_id={model_version_id}, "
                f"prompt_id={prompt_id}"
            )
        row = rows.iloc[0]
        image_url = f"{IMAGE_BASE_URL}/{row['image_id']}.png"
        data.append((
            coords[idx][0] * TSNE_PLOT_SCALE,
            coords[idx][1] * TSNE_PLOT_SCALE,
            image_url,
        ))
        # 'size' is formatted as '<width>x<height>'; the last image's size is
        # used for every glyph.
        image_size = row['size'].split('x')

    # show_with_bokeh(data, streamlit=True) gives the dot-plus-tooltip variant.
    show_with_bokeh_2(data, image_size=image_size, streamlit=True)


if __name__ == '__main__':
    main()