# zotero-weekly / app.py
import base64
import io
import os
import shutil
from collections import defaultdict
from datetime import date, datetime, timedelta
from functools import lru_cache
import dotenv
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset
from dateutil.parser import parse
from dateutil.tz import tzutc
from fasthtml.common import *
from fh_matplotlib import matplotlib2fasthtml
from huggingface_hub import login, whoami
# Load secrets from a local .env file, then authenticate against the
# Hugging Face Hub so the dataset repos below can be downloaded.
dotenv.load_dotenv()
login(token=os.environ.get("HF_TOKEN"))
# 1x1 transparent PNG shown while the real page image lazy-loads via HTMX.
PLACEHOLDER_IMAGE = (
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
)
# delete data folder
# Best-effort cleanup of any dataset cache left by a previous run; a
# failure here is only logged, never fatal.
if os.path.exists("data"):
try:
shutil.rmtree("data")
except OSError as e:
print("Error: %s : %s" % ("data", e.strerror))
@lru_cache(maxsize=None)
def load_cached_dataset(repo_id, dataset_name, split):
return load_dataset(repo_id, dataset_name, split=split)
# Resolve the authenticated user; both dataset repo ids are derived from it.
hf_user = whoami(os.environ.get("HF_TOKEN"))["name"]
HF_REPO_ID_TXT = f"{hf_user}/zotero-answer-ai-texts"
HF_REPO_ID_IMG = f"{hf_user}/zotero-answer-ai-images"
# Eagerly load all three splits at import time; load_cached_dataset memoizes,
# so repeated imports/reloads do not re-download.
abstract_ds = load_cached_dataset(HF_REPO_ID_TXT, "abstracts", "train")
article_ds = load_cached_dataset(HF_REPO_ID_TXT, "articles", "train")
image_ds = load_cached_dataset(HF_REPO_ID_IMG, "images_first_page", "train")
def parse_date(date_string):
    """Parse an arbitrary date string into a UTC calendar date.

    Falls back to today's date on unparseable input so one malformed
    record never crashes the whole page build.

    NOTE(review): astimezone() on a naive parse result assumes the string
    was in the server's local timezone — confirm the Zotero export format.
    """
    try:
        return parse(date_string).astimezone(tzutc()).date()
    except (ValueError, OverflowError, TypeError):
        # dateutil raises ValueError/OverflowError for bad strings;
        # TypeError guards against a None/non-str field in the dataset.
        return date.today()
def get_week_start(date_obj):
    """Return the Monday of the ISO week containing *date_obj*."""
    days_since_monday = date_obj.weekday()  # Monday == 0 ... Sunday == 6
    return date_obj - timedelta(days=days_since_monday)
# Bucket every article id into the Monday-anchored week it was added.
week2articles = defaultdict(list)
for article in article_ds:
date_added = parse_date(article["date_added"])
week_start = get_week_start(date_added)
week2articles[week_start].append(article["arxiv_id"])
# Newest week first; weeks[0] drives the landing page.
weeks = sorted(week2articles.keys(), reverse=True)
# O(1) lookup tables keyed by arxiv id for the three dataset splits.
arxiv2article = {article["arxiv_id"]: article for article in article_ds}
arxiv2abstract = {abstract["arxiv_id"]: abstract for abstract in abstract_ds}
arxiv2image = {image["arxiv_id"]: image for image in image_ds}
def get_article_details(arxiv_id):
    """Return the article record for *arxiv_id*, or {} when unknown.

    The abstract/image lookups that used to live here were dead code;
    images are fetched lazily via get_optimized_image instead.
    """
    return arxiv2article.get(arxiv_id, {})
# stats --
@matplotlib2fasthtml
def generate_chart():
    """Render a bar chart of papers added per week over the last 24 weeks.

    The matplotlib2fasthtml decorator converts the active pyplot figure
    into a FastHTML-embeddable image, so nothing is returned explicitly.
    """
    end_date = max(weeks)
    start_date = end_date - timedelta(weeks=23)  # 24 bars, endpoints inclusive
    dates = []
    counts = []
    current_date = start_date
    while current_date <= end_date:
        # .get() instead of [] so weeks with no papers are NOT inserted
        # into the shared week2articles defaultdict as empty lists.
        count = len(week2articles.get(current_date, []))
        dates.append(current_date.strftime("%d-%B-%Y"))
        counts.append(count)
        current_date += timedelta(weeks=1)
    plt.figure(figsize=(12, 6))
    sns.set_style("darkgrid")
    ax = sns.barplot(x=dates, y=counts)
    plt.title("Papers per Week (Last 24 Weeks)", fontsize=16, fontweight="bold")
    plt.xlabel("Week", fontsize=12)
    plt.ylabel("Number of Papers", fontsize=12)
    # Rotate and align the tick labels so they look better
    plt.xticks(rotation=45, ha="right")
    # Use a tight layout to prevent the labels from being cut off
    plt.tight_layout()
    # Add value labels on top of each bar
    for i, v in enumerate(counts):
        ax.text(i, v + 0.5, str(v), ha="center", va="bottom")
    # Increase y-axis limit slightly to accommodate the value labels;
    # floor of 1 keeps the axis sane when every count is zero.
    plt.ylim(0, max(max(counts), 1) * 1.1)
@matplotlib2fasthtml
def generate_contributions_chart():
    """Render a horizontal bar chart of article counts per uploader.

    Figure is captured by the matplotlib2fasthtml decorator; no return.
    """
    article_df = article_ds.data.to_pandas()
    added_by_df = article_df.groupby("added_by").size().reset_index(name="count")
    # Descending, so the largest contributor ends up as the top bar.
    # (The old comment claimed "ascending for bottom-to-top", contradicting
    # the code; the code's order is what we want.)
    added_by_df = added_by_df.sort_values("count", ascending=False)
    plt.figure(figsize=(12, 8))
    sns.set_style("darkgrid")
    sns.set_palette("deep")
    ax = sns.barplot(x="count", y="added_by", data=added_by_df)
    plt.title("Upload Counts", fontsize=16, fontweight="bold")
    plt.xlabel("Number of Articles Added", fontsize=12)
    plt.ylabel("User", fontsize=12)
    # Add value labels to the end of each bar
    for i, v in enumerate(added_by_df["count"]):
        ax.text(v + 0.5, i, str(v), va="center")
    # Adjust x-axis to make room for labels
    plt.xlim(0, max(added_by_df["count"]) * 1.1)
    plt.tight_layout()
# Charts are rendered once at startup and reused on every /stats request.
bar_chart = Div(generate_chart(), id="bar-chart")
# NOTE(review): despite the name and element id, this is a horizontal bar
# chart, not a pie chart; renaming would touch the /stats route too.
pie_chart = Div(generate_contributions_chart(), id="pie-chart")
#### fasthtml app ####
style = Style("""
.grid { margin-bottom: 1rem; }
.card { display: flex; flex-direction: column; }
.card img { margin-bottom: 0.5rem; width: 500px; height: 500px; object-fit: cover; }
.card img { margin-bottom: 0.5rem; }
.card h5 { margin: 0; font-size: 0.9rem; line-height: 1.2; }
.card a { color: inherit; text-decoration: none; }
.card a:hover { text-decoration: underline; }
.htmx-indicator { display: none; }
.htmx-request .htmx-indicator { display: inline; }
.htmx-request.htmx-indicator { display: inline; }
""")
app, rt = fast_app(html_style=(style,))
# Image ---
def optimize_image(pil_image, max_size=(500, 500), quality=85):
    """Downscale *pil_image* and return it as a base64 JPEG data URI.

    Note: thumbnail() resizes the passed-in image object in place.
    """
    img_byte_arr = io.BytesIO()
    pil_image.thumbnail(max_size)
    # JPEG cannot encode an alpha channel / palette: flatten RGBA/LA/P
    # images first, otherwise PIL raises OSError for PNG-sourced pages.
    if pil_image.mode not in ("RGB", "L"):
        pil_image = pil_image.convert("RGB")
    pil_image.save(img_byte_arr, format="JPEG", quality=quality, optimize=True)
    return f"data:image/jpeg;base64,{base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')}"
@lru_cache(maxsize=100)
def get_optimized_image(arxiv_id):
    """Return a base64 JPEG data URI for the paper's first page, or None."""
    record = arxiv2image.get(arxiv_id, {})
    if not record:
        return None
    return optimize_image(record["image"])
@rt("/image/{arxiv_id}")
def get(arxiv_id: str):
    """Serve the lazily-loaded article thumbnail; empty body when absent."""
    data_uri = get_optimized_image(arxiv_id)
    if not data_uri:
        return ""
    return Img(src=data_uri, alt="Article image", style="max-width: 100%; height: auto; margin-bottom: 15px;")
def generate_week_content(current_week):
    """Build the article grid plus prev/next navigation for one week.

    *current_week* must be a Monday date present in `weeks`; raises
    ValueError (via list.index) otherwise — the /week route catches it.
    """
    week_index = weeks.index(current_week)
    # `weeks` is sorted newest-first, so the previous (older) week lives
    # at a higher index and the next (newer) week at a lower one.
    prev_week = weeks[week_index + 1] if week_index < len(weeks) - 1 else None
    next_week = weeks[week_index - 1] if week_index > 0 else None
    nav_buttons = Div(
        Button(
            "← Previous Week",
            hx_get=f"/week/{prev_week}" if prev_week else "#",
            hx_target="#content",
            hx_swap="innerHTML",
            disabled=not prev_week,
        ),
        Button(
            # Was the mojibake "β†’" — a UTF-8 "→" decoded as Latin-1.
            "Next Week →",
            hx_get=f"/week/{next_week}" if next_week else "#",
            hx_target="#content",
            hx_swap="innerHTML",
            disabled=not next_week,
        ),
        A("View Stats", href="/stats", cls="button"),
    )
    articles = week2articles[current_week]
    article_cards = []
    for arxiv_id in articles:
        article = get_article_details(arxiv_id)
        # .get() guards against an arxiv id with no matching article record,
        # which previously raised KeyError on article["contents"].
        contents = article.get("contents") or []
        article_title = contents[0].get("paper_title", "article") if contents else "article"
        # Placeholder + spinner; HTMX swaps in the real image once the card
        # scrolls into view (hx_trigger="revealed").
        image_div = Div(
            Img(src=PLACEHOLDER_IMAGE, alt="Article image", style="width: 500px; height: 500px; object-fit: cover;"),
            Img(
                src="/static/loading.gif",
                alt="Loading",
                cls="htmx-indicator",
                style="position: absolute; top: 50%; left: 50%; transform: translate(-50%, -50%);",
            ),
            style="position: relative;",
            hx_get=f"/image/{arxiv_id}",
            hx_trigger="revealed",
            hx_swap="innerHTML",
        )
        title_h5 = H5(
            A(
                article_title,
                href=f"https://arxiv.org/abs/{arxiv_id}",
                target="_blank",
            )
        )
        article_cards.append(Card(image_div, title_h5, cls="mb-4"))
    grid = Grid(
        *article_cards,
        style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 1rem;",
    )
    week_end = current_week + timedelta(days=6)
    return Div(
        nav_buttons,
        Br(),
        H5(f"{current_week.strftime('%B %d')} - {week_end.strftime('%B %d, %Y')} ({len(articles)} articles)"),
        Br(),
        grid,
        nav_buttons,
        id="content",
    )
@rt("/")
def get():
    """Landing page: show the most recently started week of articles."""
    return Titled("AnswerAI Zotero Weekly", generate_week_content(weeks[0]))
@rt("/week/{date}")
def get(date: str):
    """HTMX partial for a specific week.

    *date* is the week's Monday as YYYY-MM-DD. The parameter name is part
    of the route and shadows the imported datetime.date inside this body.
    Broad except is deliberate: any failure renders as an inline error div.
    """
    try:
        current_week = datetime.strptime(date, "%Y-%m-%d").date()
        return generate_week_content(current_week)
    except Exception as e:
        return Div(f"Error displaying articles: {str(e)}")
@rt("/stats")
async def get():
    """Stats page: weekly paper counts and per-user contribution totals.

    Both charts are pre-rendered at startup (bar_chart / pie_chart); the
    per-request dataframe recomputation that used to live here was unused
    dead code and has been removed.
    """
    return Titled(
        "AnswerAI Zotero Stats",
        # Heading now matches the 24-week window the chart actually plots.
        H5("Papers per Week (Last 24 Weeks)"),
        bar_chart,
        Br(),
        H5("Contributions by User"),
        pie_chart,
        Br(),
        A("Back to Weekly View", href="/", cls="button"),
    )
# serve()
# Run uvicorn directly; hosting platforms inject $PORT (default 7860).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))