# Hugging Face Space (status: Running)
import os | |
import streamlit as st | |
import torch | |
import pandas as pd | |
import numpy as np | |
from datasets import load_dataset, Dataset, load_from_disk | |
from huggingface_hub import login | |
from streamlit_agraph import agraph, Node, Edge, Config | |
from sklearn.manifold import TSNE | |
def load_hf_dataset():
    """Load the GEMRec roster and prompt metadata from the Hugging Face Hub.

    Logs in with the token found in the ``HF_TOKEN`` environment variable
    (when one is set), downloads the ``train`` split of both datasets and
    joins the roster's model information onto the metadata rows.

    Returns:
        tuple: ``(roster, promptBook)`` as pandas DataFrames.

        * ``roster`` is restricted to the model/version identifier, name and
          download-count columns, with duplicate rows dropped.
        * ``promptBook`` is the metadata left-joined with ``roster`` on
          ``model_id`` and ``modelVersion_id``. It gains a
          ``weighted_score_sum`` column (initialised to 0 when the dataset
          does not provide one) and a ``row_idx`` column holding each row's
          index at load time.
    """
    roster_columns = ['model_id', 'model_name', 'modelVersion_id',
                      'modelVersion_name', 'model_download_count']

    # Only log in when a token is configured: login(token=None) falls back to
    # an interactive prompt, which cannot be answered inside a hosted app.
    token = os.environ.get("HF_TOKEN")
    if token:
        login(token=token)

    # load from huggingface
    roster = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Roster', split='train'))
    promptBook = pd.DataFrame(load_dataset('MAPS-research/GEMRec-Metadata', split='train'))

    # keep only the columns needed downstream and drop duplicate rows
    roster = roster[roster_columns].drop_duplicates().reset_index(drop=True)

    # add 'weighted_score_sum' column to promptBook if it does not exist
    if 'weighted_score_sum' not in promptBook.columns:
        promptBook.loc[:, 'weighted_score_sum'] = 0

    # roster already holds exactly the columns to attach, so merge it directly
    promptBook = promptBook.merge(roster, on=['model_id', 'modelVersion_id'], how='left')

    # record the current row index so rows can be traced after filtering
    promptBook.loc[:, 'row_idx'] = promptBook.index

    return roster, promptBook
def calc_tsne(prompt_id, feats_dir='../data/feats'):
    """Project the stored features of one prompt to 2-D with t-SNE.

    Only the feature file belonging to ``prompt_id`` is loaded; other files
    in ``feats_dir`` are skipped.

    Args:
        prompt_id: Prompt identifier as a string of digits. It is compared
            with the part of each file name before the first dot.
        feats_dir: Directory holding the ``.pt`` feature files.

    Returns:
        pandas.DataFrame with one row per embedded sample and the columns
        ``x`` and ``y`` (t-SNE coordinates), ``prompt_id`` and
        ``modelVersion_id``.

    Raises:
        KeyError: If ``feats_dir`` contains no feature file for ``prompt_id``.
    """
    print('==> loading feats')
    entry = None
    for file_name in os.listdir(feats_dir):
        parts = file_name.split('.')
        # Same filter as before (digit stem, 'pt' extension), but restricted
        # to the requested prompt so unrelated files are never deserialised.
        if parts[-1] == 'pt' and parts[0].isdigit() and parts[0] == prompt_id:
            entry = torch.load(os.path.join(feats_dir, file_name))
    if entry is None:
        raise KeyError(prompt_id)

    print('==> applying t-SNE')
    # fixed random_state keeps the layout reproducible between reruns
    tsne = TSNE(n_components=2, random_state=0)
    coords = tsne.fit_transform(entry['all'].numpy())

    feats_df = pd.DataFrame(coords, columns=['x', 'y'])
    feats_df['prompt_id'] = prompt_id
    # Every key other than 'all' / 'tsne' is read as a model-version id via
    # ``.item()``.
    # NOTE(review): assumes key order matches the row order of 'all' - confirm.
    feats_df['modelVersion_id'] = [
        int(key.item()) for key in entry if key != 'all' and key != 'tsne'
    ]
    return feats_df
if __name__ == '__main__':
    # Streamlit page: show the images generated for one prompt as nodes of a
    # graph, positioned by the t-SNE projection of their features.
    st.set_page_config(layout="wide")

    # load dataset
    roster, promptBook = load_hf_dataset()

    with st.sidebar:
        st.write('## Select Prompt')
        # sort prompts by prompt_id
        prompts = sorted(promptBook['prompt_id'].unique().tolist())
        prompt_id = st.selectbox('Select Prompt', prompts, index=0)
        physics = st.checkbox('Enable Physics')

    feats_df = calc_tsne(str(prompt_id))

    # Build the modelVersion_id -> image_id lookup once instead of filtering
    # promptBook again for every node; the first matching row wins.
    prompt_rows = promptBook[promptBook['prompt_id'] == int(prompt_id)]
    image_ids = (prompt_rows.drop_duplicates('modelVersion_id')
                 .set_index('modelVersion_id')['image_id']
                 .to_dict())

    scale = 50  # spreads the t-SNE coordinates out on the canvas
    nodes = []
    edges = []  # the graph is a pure scatter of image nodes
    for row in feats_df.itertuples(index=False):
        image_id = image_ids.get(row.modelVersion_id)
        if image_id is None:
            # no image recorded for this model under the selected prompt
            continue
        image_url = f"https://modelcofferbucket.s3-accelerate.amazonaws.com/{image_id}.png"
        nodes.append(
            Node(id=image_url,
                 size=20,
                 shape="image",
                 image=image_url,
                 # plain Python floats: coordinates must be scalar numbers
                 # and serialisable when sent to the front end
                 x=float(row.x) * scale,
                 y=float(row.y) * scale,
                 # pin nodes at their t-SNE position unless physics is on
                 fixed=not physics,
                 color={'background': '#000000', 'border': '#ffffff'},
                 shadow={'enabled': True, 'color': 'rgba(0,0,0,0.4)', 'size': 10, 'x': 1, 'y': 1},
                 )
        )

    config = Config(width='100%',
                    height=800,
                    directed=True,
                    physics=physics,
                    hierarchical=False,
                    )

    cols = st.columns([3, 1], gap='large')
    with cols[0]:
        # presumably returns the id (here: the image URL) of the selected
        # node, or a falsy value when nothing is selected - verify against
        # streamlit_agraph
        return_value = agraph(nodes=nodes,
                              edges=edges,
                              config=config)

    with cols[1]:
        if return_value:
            try:
                st.image(return_value, use_column_width=True)
            except Exception:
                # keep the page usable if the image cannot be displayed
                st.write('No image selected')
        else:
            st.write('No image selected')