Spaces:
Running
Running
update bar chart
Browse files- app.py +46 -52
- test_altair.py +22 -47
app.py
CHANGED
@@ -20,16 +20,38 @@ SCORE_NAME_MAPPING = {'clip': 'clip_score', 'rank': 'avg_rank', 'pop': 'model_do
|
|
20 |
|
21 |
# hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
|
22 |
@st.cache_resource
|
23 |
-
def altair_histogram(hist_data, sort_by):
|
24 |
brushed = alt.selection_interval(encodings=['x'], name="brushed")
|
25 |
-
|
|
|
26 |
alt.Chart(hist_data)
|
27 |
-
.mark_bar()
|
28 |
-
.encode(alt.X(f"{sort_by}:Q", bin=
|
29 |
-
.add_selection(brushed)
|
30 |
-
.properties(width=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
)
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
class GalleryApp:
|
34 |
def __init__(self, promptBook, images_ds):
|
35 |
self.promptBook = promptBook
|
@@ -169,7 +191,6 @@ class GalleryApp:
|
|
169 |
|
170 |
return items, info, col_num
|
171 |
|
172 |
-
|
173 |
def selection_panel_2(self, items):
|
174 |
selecters = st.columns([1, 5])
|
175 |
|
@@ -226,14 +247,25 @@ class GalleryApp:
|
|
226 |
items = items[items['checked'] == True].reset_index(drop=True)
|
227 |
print(items)
|
228 |
|
|
|
229 |
if sort_type == 'Scores':
|
230 |
-
st.
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
|
238 |
info = st.multiselect('Show Info',
|
239 |
['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
|
@@ -308,7 +340,6 @@ class GalleryApp:
|
|
308 |
except:
|
309 |
pass
|
310 |
|
311 |
-
|
312 |
# add safety check for some prompts
|
313 |
safety_check = True
|
314 |
unsafe_prompts = {}
|
@@ -398,44 +429,7 @@ if __name__ == '__main__':
|
|
398 |
login(token=os.environ.get("HF_TOKEN"))
|
399 |
st.set_page_config(layout="wide")
|
400 |
|
401 |
-
# if 'roster' not in st.session_state:
|
402 |
-
# print('loading roster')
|
403 |
-
# # st.session_state.roster = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferRoster', split='train'))
|
404 |
-
# st.session_state.roster = pd.DataFrame(load_from_disk(os.path.join(os.getcwd(), 'data', 'roster')))
|
405 |
-
# st.session_state.roster = st.session_state.roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name',
|
406 |
-
# 'model_download_count']].drop_duplicates().reset_index(drop=True)
|
407 |
-
# # add model download count from roster to promptbook dataframe
|
408 |
-
# if 'promptBook' not in st.session_state:
|
409 |
-
# print('loading promptBook')
|
410 |
-
#
|
411 |
-
# st.session_state.promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferMetadata', split='train'))
|
412 |
-
# # add 'checked' column to promptBook if not exist
|
413 |
-
# if 'checked' not in st.session_state.promptBook.columns:
|
414 |
-
# st.session_state.promptBook.loc[:, 'checked'] = False
|
415 |
-
#
|
416 |
-
# # add 'custom_score_weights' column to promptBook if not exist
|
417 |
-
# if 'weighted_score_sum' not in st.session_state.promptBook.columns:
|
418 |
-
# st.session_state.promptBook.loc[:, 'weighted_score_sum'] = 0
|
419 |
-
#
|
420 |
-
# st.session_state.images = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
|
421 |
-
# # st.session_state.images = load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train', streaming=True)
|
422 |
-
# print(st.session_state.images)
|
423 |
-
# print('images loaded')
|
424 |
-
# # st.session_state.promptBook = pd.DataFrame(load_dataset('NYUSHPRP/ModelCofferPromptBook', split='train'))
|
425 |
-
# st.session_state.promptBook = st.session_state.promptBook.merge(st.session_state.roster[['model_id', 'model_name', 'modelVersion_id', 'modelVersion_name', 'model_download_count']], on=['model_id', 'modelVersion_id'], how='left')
|
426 |
-
#
|
427 |
-
# # add column to record current row index
|
428 |
-
# st.session_state.promptBook['row_idx'] = st.session_state.promptBook.index
|
429 |
-
# print('promptBook loaded')
|
430 |
-
# # print(st.session_state.promptBook)
|
431 |
-
#
|
432 |
-
# check_roster_error = False
|
433 |
-
# if check_roster_error:
|
434 |
-
# # print all rows with the same model_id and modelVersion_id but different model_download_count in roster
|
435 |
-
# print(st.session_state.roster[st.session_state.roster.duplicated(subset=['model_id', 'modelVersion_id'], keep=False)].sort_values(by=['model_id', 'modelVersion_id']))
|
436 |
roster, promptBook, images_ds = load_hf_dataset()
|
437 |
-
# if 'images' not in st.session_state:
|
438 |
-
# st.session_state.images = load_from_disk(os.path.join(os.getcwd(), 'data', 'promptbook'))
|
439 |
|
440 |
app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
|
441 |
app.app()
|
|
|
20 |
|
21 |
# hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["x"])
|
22 |
@st.cache_resource
|
23 |
+
def altair_histogram(hist_data, sort_by, mini, maxi):
|
24 |
brushed = alt.selection_interval(encodings=['x'], name="brushed")
|
25 |
+
|
26 |
+
chart = (
|
27 |
alt.Chart(hist_data)
|
28 |
+
.mark_bar(opacity=0.7, cornerRadius=2)
|
29 |
+
.encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=20)), y="count()")
|
30 |
+
# .add_selection(brushed)
|
31 |
+
# .properties(width=800, height=300)
|
32 |
+
)
|
33 |
+
|
34 |
+
# Create a transparent rectangle for highlighting the range
|
35 |
+
highlight = (
|
36 |
+
alt.Chart(pd.DataFrame({'x1': [mini], 'x2': [maxi]}))
|
37 |
+
.mark_rect(opacity=0.3)
|
38 |
+
.encode(x='x1', x2='x2')
|
39 |
+
# .properties(width=800, height=300)
|
40 |
)
|
41 |
|
42 |
+
# Layer the chart and the highlight rectangle
|
43 |
+
layered_chart = alt.layer(chart, highlight)
|
44 |
+
|
45 |
+
return layered_chart
|
46 |
+
|
47 |
+
# return (
|
48 |
+
# alt.Chart(hist_data)
|
49 |
+
# .mark_bar()
|
50 |
+
# .encode(alt.X(f"{sort_by}:Q", bin=alt.Bin(maxbins=20)), y="count()")
|
51 |
+
# .add_selection(brushed)
|
52 |
+
# .properties(width=600, height=300)
|
53 |
+
# )
|
54 |
+
|
55 |
class GalleryApp:
|
56 |
def __init__(self, promptBook, images_ds):
|
57 |
self.promptBook = promptBook
|
|
|
191 |
|
192 |
return items, info, col_num
|
193 |
|
|
|
194 |
def selection_panel_2(self, items):
|
195 |
selecters = st.columns([1, 5])
|
196 |
|
|
|
247 |
items = items[items['checked'] == True].reset_index(drop=True)
|
248 |
print(items)
|
249 |
|
250 |
+
# draw a distribution histogram
|
251 |
if sort_type == 'Scores':
|
252 |
+
with st.expander('Show score distribution histogram and select score range'):
|
253 |
+
st.write('**Score distribution histogram**')
|
254 |
+
chart_space = st.container()
|
255 |
+
# st.write('Select the range of scores to show')
|
256 |
+
hist_data = pd.DataFrame(items[sort_by])
|
257 |
+
mini = hist_data[sort_by].min().item()
|
258 |
+
maxi = hist_data[sort_by].max().item()
|
259 |
+
st.write('**Select the range of scores to show**')
|
260 |
+
r = st.slider('Select the range of scores to show', min_value=mini, max_value=maxi, value=(mini, maxi), label_visibility='collapsed')
|
261 |
+
with chart_space:
|
262 |
+
st.altair_chart(altair_histogram(hist_data, sort_by, r[0], r[1]), use_container_width=True)
|
263 |
+
# event_dict = altair_component(altair_chart=altair_histogram(hist_data, sort_by))
|
264 |
+
# r = event_dict.get(sort_by)
|
265 |
+
if r:
|
266 |
+
items = items[(items[sort_by] >= r[0]) & (items[sort_by] <= r[1])].reset_index(drop=True)
|
267 |
+
# st.write(r)
|
268 |
+
|
269 |
|
270 |
info = st.multiselect('Show Info',
|
271 |
['model_download_count', 'clip_score', 'avg_rank', 'model_name', 'model_id',
|
|
|
340 |
except:
|
341 |
pass
|
342 |
|
|
|
343 |
# add safety check for some prompts
|
344 |
safety_check = True
|
345 |
unsafe_prompts = {}
|
|
|
429 |
login(token=os.environ.get("HF_TOKEN"))
|
430 |
st.set_page_config(layout="wide")
|
431 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
432 |
roster, promptBook, images_ds = load_hf_dataset()
|
|
|
|
|
433 |
|
434 |
app = GalleryApp(promptBook=promptBook, images_ds=images_ds)
|
435 |
app.app()
|
test_altair.py
CHANGED
@@ -1,50 +1,25 @@
|
|
1 |
-
import altair as alt
|
2 |
import streamlit as st
|
|
|
3 |
import pandas as pd
|
4 |
-
import numpy as np
|
5 |
-
|
6 |
-
from streamlit_vega_lite import vega_lite_component, altair_component, _component_func
|
7 |
-
|
8 |
-
hist_data = pd.DataFrame(np.random.normal(42, 10, (200, 1)), columns=["abc"])
|
9 |
-
print(hist_data)
|
10 |
-
|
11 |
-
@st.cache_resource
|
12 |
-
def altair_histogram():
|
13 |
-
brushed = alt.selection_interval(encodings=["x"], name="brushed")
|
14 |
-
|
15 |
-
return (
|
16 |
-
alt.Chart(hist_data)
|
17 |
-
.mark_bar()
|
18 |
-
.encode(alt.X("abc:Q", bin=True), y="count()")
|
19 |
-
.add_selection(brushed)
|
20 |
-
)
|
21 |
-
|
22 |
-
chart = altair_histogram()
|
23 |
-
res = st.altair_chart(chart, use_container_width=True)
|
24 |
-
# print(res)
|
25 |
-
event_dict = altair_component(altair_chart=altair_histogram())
|
26 |
-
chart_dict = chart.to_dict()
|
27 |
-
print(chart_dict)
|
28 |
-
altair_chart = chart.copy()
|
29 |
-
datasets = {}
|
30 |
-
|
31 |
-
def id_transform(data):
|
32 |
-
"""Altair data transformer that returns a fake named dataset with the
|
33 |
-
object id."""
|
34 |
-
name = f"d{id(data)}"
|
35 |
-
datasets[name] = data
|
36 |
-
return {"name": name}
|
37 |
-
|
38 |
-
alt.data_transformers.register("id", id_transform)
|
39 |
-
|
40 |
-
with alt.data_transformers.enable("id"):
|
41 |
-
chart_dict = altair_chart.to_dict()
|
42 |
-
# st.write(event_dict)
|
43 |
-
|
44 |
-
event_dict = _component_func(spec=chart_dict, **datasets, key=None, default={})
|
45 |
-
# print(chart_dict)
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
import altair as alt
|
3 |
import pandas as pd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
+
# Generate random data for the chart
|
6 |
+
data = pd.DataFrame({
|
7 |
+
'Category': ['A', 'B', 'C', 'D', 'E'],
|
8 |
+
'Value': [0.2, 0.5, 0.8, 1.2, 1.5]
|
9 |
+
})
|
10 |
+
|
11 |
+
# Define the color scale for the bars
|
12 |
+
color_scale = alt.Scale(
|
13 |
+
domain=[0, 1], # Values between 0 and 1 will be blue
|
14 |
+
range=['steelblue', 'lightgray']
|
15 |
+
)
|
16 |
+
|
17 |
+
# Create the bar chart using Altair
|
18 |
+
chart = alt.Chart(data).mark_bar().encode(
|
19 |
+
x='Category',
|
20 |
+
y='Value',
|
21 |
+
color=alt.Color('Value', scale=color_scale)
|
22 |
+
)
|
23 |
+
|
24 |
+
# Render the chart using Streamlit
|
25 |
+
st.altair_chart(chart, use_container_width=True)
|