Spaces:
Sleeping
Sleeping
import base64 | |
import io | |
import os | |
import shutil | |
from collections import defaultdict | |
from datetime import date, datetime, timedelta | |
from functools import lru_cache | |
import dotenv | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from datasets import load_dataset | |
from dateutil.parser import parse | |
from dateutil.tz import tzutc | |
from fasthtml.common import * | |
from fh_matplotlib import matplotlib2fasthtml | |
from huggingface_hub import login, whoami | |
# Pull variables from a local .env file (if any) into the environment, then
# authenticate against the Hugging Face Hub with the HF_TOKEN variable.
dotenv.load_dotenv()
login(token=os.getenv("HF_TOKEN"))
# A 1x1 PNG as a data URI. Every article card shows it until the real
# first-page image is fetched lazily by htmx.
PLACEHOLDER_IMAGE = (
    "data:image/png;base64,"
    "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
)
# Remove any existing local "data" directory before the datasets are loaded.
# A failure to delete it is reported on stdout but does not stop startup.
if os.path.exists("data"):
    try:
        shutil.rmtree("data")
    except OSError as err:
        print(f"Error: data : {err.strerror}")
@lru_cache(maxsize=None)
def load_cached_dataset(repo_id, dataset_name, split):
    """Load one split of a Hub dataset configuration, memoized per argument triple.

    Args:
        repo_id: Hub dataset repository id ("user/name").
        dataset_name: configuration name inside the repository.
        split: split to load (e.g. "train").

    Returns:
        Whatever ``datasets.load_dataset`` returns for these arguments. Repeated
        calls with the same arguments return the same cached object instead of
        reloading. All arguments are strings, so they are hashable cache keys.
    """
    return load_dataset(repo_id, dataset_name, split=split)
# The datasets live under the namespace of the account that owns HF_TOKEN.
hf_user = whoami(token=os.getenv("HF_TOKEN"))["name"]
HF_REPO_ID_TXT = f"{hf_user}/zotero-answer-ai-texts"
HF_REPO_ID_IMG = f"{hf_user}/zotero-answer-ai-images"

# "train" split of the "abstracts" and "articles" configs of the text repo ...
abstract_ds = load_cached_dataset(HF_REPO_ID_TXT, "abstracts", "train")
article_ds = load_cached_dataset(HF_REPO_ID_TXT, "articles", "train")
# ... and of the "images_first_page" config of the image repo.
image_ds = load_cached_dataset(HF_REPO_ID_IMG, "images_first_page", "train")
def parse_date(date_string):
    """Parse a date/time string and return its calendar date in UTC.

    Naive timestamps are interpreted in the server's local time zone by
    ``datetime.astimezone`` before conversion to UTC.

    Args:
        date_string: value to parse with ``dateutil.parser.parse``.

    Returns:
        The UTC ``date`` of the parsed timestamp. If the value cannot be
        parsed, returns today's date in UTC so it matches the success path.
    """
    try:
        return parse(date_string).astimezone(tzutc()).date()
    except (ValueError, TypeError, OverflowError):
        # dateutil raises ParserError (a ValueError subclass) for malformed
        # strings, TypeError for non-string input such as None, and
        # OverflowError for out-of-range numeric values.
        return datetime.now(tzutc()).date()
def get_week_start(date_obj):
    """Return the Monday that begins the week containing *date_obj*."""
    days_since_monday = date_obj.weekday()  # Monday == 0 ... Sunday == 6
    return date_obj - timedelta(days=days_since_monday)
# Lookup tables built once at startup:
#   week2articles  Monday of the week an article was added -> [arxiv_id, ...]
#   weeks          those Mondays, newest first
#   arxiv2*        arxiv_id -> dataset row (articles, abstracts, images)
week2articles = defaultdict(list)
arxiv2article = {}
for article in article_ds:
    date_added = parse_date(article["date_added"])
    week_start = get_week_start(date_added)
    week2articles[week_start].append(article["arxiv_id"])
    # Plain assignment: a later row with the same id replaces an earlier one.
    arxiv2article[article["arxiv_id"]] = article
weeks = sorted(week2articles, reverse=True)

arxiv2abstract = {row["arxiv_id"]: row for row in abstract_ds}
arxiv2image = {row["arxiv_id"]: row for row in image_ds}
def get_article_details(arxiv_id):
    """Return the article row for *arxiv_id*, or an empty dict if it is unknown."""
    return arxiv2article.get(arxiv_id, {})
# stats --
@matplotlib2fasthtml
def generate_chart():
    """Draw a bar chart of papers added per week over the last 24 weeks.

    The 24 weekly buckets end at the most recent week with articles
    (``max(weeks)``). Weeks without articles appear as zero-height bars.
    ``matplotlib2fasthtml`` converts the figure drawn here into an image
    component for the page. Without the decorator this function returns None.
    """
    end_date = max(weeks)
    start_date = end_date - timedelta(weeks=23)  # 23 steps back -> 24 bars

    dates = []
    counts = []
    current_date = start_date
    while current_date <= end_date:
        # .get() rather than indexing: indexing the defaultdict would insert
        # an empty list into the shared week2articles for every empty week.
        counts.append(len(week2articles.get(current_date, [])))
        dates.append(current_date.strftime("%d-%B-%Y"))
        current_date += timedelta(weeks=1)

    plt.figure(figsize=(12, 6))
    sns.set_style("darkgrid")
    ax = sns.barplot(x=dates, y=counts)
    plt.title("Papers per Week (Last 24 Weeks)", fontsize=16, fontweight="bold")
    plt.xlabel("Week", fontsize=12)
    plt.ylabel("Number of Papers", fontsize=12)
    plt.xticks(rotation=45, ha="right")

    # Value label above each bar.
    for i, v in enumerate(counts):
        ax.text(i, v + 0.5, str(v), ha="center", va="bottom")
    # Labels sit 0.5 above their bar, so add absolute headroom as well as the
    # 10% margin; otherwise small counts push the label outside the axes.
    plt.ylim(0, max(counts) * 1.1 + 1)

    # Lay out last, so the final tick labels, value labels and limits are included.
    plt.tight_layout()
@matplotlib2fasthtml
def generate_contributions_chart():
    """Draw a horizontal bar chart of how many articles each user added.

    Counts come from grouping the article rows by their "added_by" column.
    ``matplotlib2fasthtml`` converts the figure into an image component for
    the page. Without the decorator this function returns None.
    """
    article_df = article_ds.data.to_pandas()
    added_by_df = article_df.groupby("added_by").size().reset_index(name="count")
    # Descending: seaborn draws the first category of a horizontal bar plot at
    # the top, so the largest contributor ends up on top.
    added_by_df = added_by_df.sort_values("count", ascending=False)

    plt.figure(figsize=(12, 8))
    sns.set_style("darkgrid")
    sns.set_palette("deep")
    ax = sns.barplot(x="count", y="added_by", data=added_by_df)
    plt.title("Upload Counts", fontsize=16, fontweight="bold")
    plt.xlabel("Number of Articles Added", fontsize=12)
    plt.ylabel("User", fontsize=12)

    # Label each bar just past its end. Row i of added_by_df is category i.
    for i, v in enumerate(added_by_df["count"]):
        ax.text(v + 0.5, i, str(v), va="center")
    # Room to the right of the longest bar for its label, even for small counts.
    plt.xlim(0, max(added_by_df["count"]) * 1.1 + 1)
    plt.tight_layout()
# Both stats charts are rendered once, at import time. The stats page handler
# embeds these same components on every request.
bar_chart = Div(generate_chart(), id="bar-chart")
# NOTE: the id says "pie", but generate_contributions_chart draws a
# horizontal bar plot (sns.barplot).
pie_chart = Div(generate_contributions_chart(), id="pie-chart")
#### fasthtml app ####
# Page-wide CSS: card layout for the weekly article grid, plus the htmx
# loading indicator, which stays hidden except during an htmx request. The
# two .htmx-request selectors cover an indicator nested inside the requesting
# element and an indicator that is itself the requesting element.
style = Style("""
    .grid { margin-bottom: 1rem; }
    .card { display: flex; flex-direction: column; }
    .card img { margin-bottom: 0.5rem; width: 500px; height: 500px; object-fit: cover; }
    .card h5 { margin: 0; font-size: 0.9rem; line-height: 1.2; }
    .card a { color: inherit; text-decoration: none; }
    .card a:hover { text-decoration: underline; }
    .htmx-indicator { display: none; }
    .htmx-request .htmx-indicator { display: inline; }
    .htmx-request.htmx-indicator { display: inline; }
""")
# The page CSS must go in `hdrs` (elements added to <head>). `fast_app` has no
# `html_style` parameter, so passing the Style that way never puts it in the head.
app, rt = fast_app(hdrs=(style,))
# Image ---
def optimize_image(pil_image, max_size=(500, 500), quality=85):
    """Downscale a PIL image and return it as a base64 JPEG data URI.

    Args:
        pil_image: PIL image to encode. ``thumbnail`` resizes in place
            (keeping the aspect ratio, never enlarging), so an RGB or L input
            image is itself shrunk by this call.
        max_size: (width, height) bounding box for the downscaled image.
        quality: JPEG quality passed to ``Image.save``.

    Returns:
        A ``data:image/jpeg;base64,...`` string usable as an ``<img>`` src.
    """
    # JPEG has no alpha channel or palette. Pillow refuses to write modes such
    # as RGBA, P or LA as JPEG, so convert those to RGB first.
    if pil_image.mode not in ("RGB", "L"):
        pil_image = pil_image.convert("RGB")
    pil_image.thumbnail(max_size)

    buffer = io.BytesIO()
    pil_image.save(buffer, format="JPEG", quality=quality, optimize=True)
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{encoded}"
def get_optimized_image(arxiv_id):
    """Return a JPEG data URI for the article's image, or None if there is none.

    Looks *arxiv_id* up in ``arxiv2image``. A missing or empty row yields None.
    """
    record = arxiv2image.get(arxiv_id, {})
    if not record:
        return None
    return optimize_image(record["image"])
@rt("/image/{arxiv_id}")
def get(arxiv_id: str):
    """Serve the lazily loaded image for one article card.

    Registered at ``/image/{arxiv_id}``, the URL each card's image slot
    requests via htmx. Every handler in this file is named ``get``, so without
    the route decorator it would simply be overwritten and never served.

    Returns:
        An ``Img`` with an inline JPEG data URI, or an empty string when there
        is no image for *arxiv_id*.
    """
    image_url = get_optimized_image(arxiv_id)
    if image_url:
        return Img(src=image_url, alt="Article image", style="max-width: 100%; height: auto; margin-bottom: 15px;")
    return ""
def _week_nav_buttons(prev_week, next_week):
    """Build the previous/next week buttons plus the link to the stats page."""
    return Div(
        Button(
            "← Previous Week",
            hx_get=f"/week/{prev_week}" if prev_week else "#",
            hx_target="#content",
            # outerHTML: the response is itself a Div(id="content"), so it must
            # replace #content rather than be nested inside it (duplicate ids).
            hx_swap="outerHTML",
            disabled=not prev_week,
        ),
        Button(
            "Next Week →",
            hx_get=f"/week/{next_week}" if next_week else "#",
            hx_target="#content",
            hx_swap="outerHTML",
            disabled=not next_week,
        ),
        A("View Stats", href="/stats", cls="button"),
    )


def _article_card(arxiv_id):
    """Build one article card: a lazily loaded image above the linked title."""
    article = get_article_details(arxiv_id)
    # Rows without (or with empty) "contents" fall back to a generic title.
    contents = article.get("contents") or []
    article_title = contents[0].get("paper_title", "article") if contents else "article"

    image_slot = Div(
        Img(src=PLACEHOLDER_IMAGE, alt="Article image", style="width: 500px; height: 500px; object-fit: cover;"),
        Img(
            src="/static/loading.gif",
            alt="Loading",
            cls="htmx-indicator",
            style="position: absolute; top: 50%; left: 50%; transform: translate(-50%, -50%);",
        ),
        style="position: relative;",
        # Fetch the real image only when the card scrolls into view; the
        # response replaces the placeholder and the spinner.
        hx_get=f"/image/{arxiv_id}",
        hx_trigger="revealed",
        hx_swap="innerHTML",
    )
    title = H5(A(article_title, href=f"https://arxiv.org/abs/{arxiv_id}", target="_blank"))
    return Card(image_slot, title, cls="mb-4")


def generate_week_content(current_week):
    """Render the article grid for the week starting on *current_week*.

    Args:
        current_week: week-start date. It must be an element of ``weeks``,
            otherwise ``weeks.index`` raises ValueError.

    Returns:
        A ``Div(id="content")`` with navigation above and below, a heading
        with the week's date range and article count, and a three-column grid
        of article cards.
    """
    week_index = weeks.index(current_week)
    # weeks is sorted newest-first, so the older ("previous") week is at index + 1.
    prev_week = weeks[week_index + 1] if week_index + 1 < len(weeks) else None
    next_week = weeks[week_index - 1] if week_index > 0 else None
    nav_buttons = _week_nav_buttons(prev_week, next_week)

    articles = week2articles[current_week]
    grid = Grid(
        *(_article_card(arxiv_id) for arxiv_id in articles),
        style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 1rem;",
    )

    week_end = current_week + timedelta(days=6)
    return Div(
        nav_buttons,
        Br(),
        H5(f"{current_week.strftime('%B %d')} - {week_end.strftime('%B %d, %Y')} ({len(articles)} articles)"),
        Br(),
        grid,
        nav_buttons,
        id="content",
    )
@rt("/")
def get():
    """Render the home page, opened on the most recent week with articles.

    The ``@rt("/")`` registration is required: every handler in this file is
    named ``get``, so an undecorated definition is just overwritten.
    """
    return Titled("AnswerAI Zotero Weekly", generate_week_content(weeks[0]))
@rt("/week/{date}")
def get(date: str):
    """Return the content fragment for the week starting on *date* (YYYY-MM-DD).

    This is the URL the previous/next week buttons request via htmx. The
    parameter name must match the ``{date}`` path segment. Inside this handler
    it shadows the module-level ``date`` import, which the handler does not use.

    Any failure is reported in a Div instead of a server error. This includes
    an unparsable date, and a date not present in ``weeks``, for which
    ``weeks.index`` raises ValueError.
    """
    try:
        current_week = datetime.strptime(date, "%Y-%m-%d").date()
        return generate_week_content(current_week)
    except Exception as e:
        return Div(f"Error displaying articles: {str(e)}")
@rt("/stats")
async def get():
    """Render the statistics page.

    Embeds the module-level ``bar_chart`` and ``pie_chart`` components, which
    are built once at import. No per-request data processing is done here.
    """
    return Titled(
        "AnswerAI Zotero Stats",
        # generate_chart plots 24 weekly buckets ("Last 24 Weeks").
        H5("Papers per Week (Last 24 Weeks)"),
        bar_chart,
        Br(),
        H5("Contributions by User"),
        pie_chart,
        Br(),
        A("Back to Weekly View", href="/", cls="button"),
    )
# Direct execution: serve the app with uvicorn on all interfaces. The port
# comes from the PORT environment variable and defaults to 7860.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))