rbiswasfc commited on
Commit
8e97bca
1 Parent(s): 52b731a
Files changed (1) hide show
  1. app.py +135 -100
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import base64
 
2
  import os
3
  import shutil
4
  from collections import defaultdict
5
  from datetime import date, datetime, timedelta
6
- from io import BytesIO
7
 
8
  import dotenv
9
  import matplotlib.pyplot as plt
@@ -16,16 +17,12 @@ from fh_matplotlib import matplotlib2fasthtml
16
  from huggingface_hub import login, whoami
17
 
18
  dotenv.load_dotenv()
 
19
 
20
- style = Style("""
21
- .grid { margin-bottom: 1rem; }
22
- .card { display: flex; flex-direction: column; }
23
- .card img { margin-bottom: 0.5rem; }
24
- .card h5 { margin: 0; font-size: 0.9rem; line-height: 1.2; }
25
- .card a { color: inherit; text-decoration: none; }
26
- .card a:hover { text-decoration: underline; }
27
- """)
28
 
 
 
 
29
 
30
  # delete data folder
31
  if os.path.exists("data"):
@@ -34,18 +31,19 @@ if os.path.exists("data"):
34
  except OSError as e:
35
  print("Error: %s : %s" % ("data", e.strerror))
36
 
37
- app, rt = fast_app(html_style=(style,))
38
 
39
- login(token=os.environ.get("HF_TOKEN"))
 
 
 
40
 
41
  hf_user = whoami(os.environ.get("HF_TOKEN"))["name"]
42
  HF_REPO_ID_TXT = f"{hf_user}/zotero-answer-ai-texts"
43
  HF_REPO_ID_IMG = f"{hf_user}/zotero-answer-ai-images"
44
 
45
- abstract_ds = load_dataset(HF_REPO_ID_TXT, "abstracts", split="train")
46
- article_ds = load_dataset(HF_REPO_ID_TXT, "articles", split="train")
47
-
48
- image_ds = load_dataset(HF_REPO_ID_IMG, "images_first_page", split="train")
49
 
50
 
51
  def parse_date(date_string):
@@ -74,9 +72,116 @@ arxiv2image = {image["arxiv_id"]: image for image in image_ds}
74
 
75
  def get_article_details(arxiv_id):
76
  article = arxiv2article.get(arxiv_id, {})
77
- abstract = arxiv2abstract.get(arxiv_id, {})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  image = arxiv2image.get(arxiv_id, {})
79
- return article, abstract, image
 
 
 
 
 
 
 
 
 
 
80
 
81
 
82
  def generate_week_content(current_week):
@@ -105,7 +210,7 @@ def generate_week_content(current_week):
105
  articles = week2articles[current_week]
106
  article_cards = []
107
  for arxiv_id in articles:
108
- article, abstract, image = get_article_details(arxiv_id)
109
  article_title = article["contents"][0].get("paper_title", "article") if article["contents"] else "article"
110
 
111
  card_content = [
@@ -118,21 +223,18 @@ def generate_week_content(current_week):
118
  )
119
  ]
120
 
121
- if image:
122
- pil_image = image["image"] # image[0]["image"]
123
- pil_image.thumbnail((500, 500))
124
- img_byte_arr = BytesIO()
125
- pil_image.save(img_byte_arr, format="JPEG")
126
- img_byte_arr = img_byte_arr.getvalue()
127
- image_url = f"data:image/jpeg;base64,{base64.b64encode(img_byte_arr).decode('utf-8')}"
128
- card_content.insert(
129
- 0,
130
- Img(
131
- src=image_url,
132
- alt="Article image",
133
- style="max-width: 100%; height: auto; margin-bottom: 15px;",
134
- ),
135
- )
136
 
137
  article_cards.append(Card(*card_content, cls="mb-4"))
138
 
@@ -169,73 +271,6 @@ def get(date: str):
169
 
170
  @rt("/stats")
171
  async def get():
172
- @matplotlib2fasthtml
173
- def generate_chart():
174
- end_date = max(weeks)
175
- start_date = end_date - timedelta(weeks=11)
176
-
177
- dates = []
178
- counts = []
179
- current_date = start_date
180
- while current_date <= end_date:
181
- count = len(week2articles[current_date])
182
- date_str = current_date.strftime("%d-%B-%Y")
183
- dates.append(date_str)
184
- counts.append(count)
185
- current_date += timedelta(weeks=1)
186
-
187
- plt.figure(figsize=(12, 6))
188
- sns.set_style("darkgrid")
189
- # sns.set_palette("deep")
190
-
191
- ax = sns.barplot(x=dates, y=counts)
192
-
193
- plt.title("Papers per Week (Last 12 Weeks)", fontsize=16, fontweight="bold")
194
- plt.xlabel("Week", fontsize=12)
195
- plt.ylabel("Number of Papers", fontsize=12)
196
-
197
- # Rotate and align the tick labels so they look better
198
- plt.xticks(rotation=45, ha="right")
199
-
200
- # Use a tight layout to prevent the labels from being cut off
201
- plt.tight_layout()
202
-
203
- # Add value labels on top of each bar
204
- for i, v in enumerate(counts):
205
- ax.text(i, v + 0.5, str(v), ha="center", va="bottom")
206
-
207
- # Increase y-axis limit slightly to accommodate the value labels
208
- plt.ylim(0, max(counts) * 1.1)
209
-
210
- @matplotlib2fasthtml
211
- def generate_contributions_chart():
212
- article_df = article_ds.data.to_pandas()
213
- added_by_df = article_df.groupby("added_by").size().reset_index(name="count")
214
- added_by_df = added_by_df.sort_values("count", ascending=False) # Ascending for bottom-to-top order
215
-
216
- plt.figure(figsize=(12, 8))
217
- sns.set_style("darkgrid")
218
- sns.set_palette("deep")
219
-
220
- ax = sns.barplot(x="count", y="added_by", data=added_by_df)
221
-
222
- plt.title("Upload Counts", fontsize=16, fontweight="bold")
223
- plt.xlabel("Number of Articles Added", fontsize=12)
224
- plt.ylabel("User", fontsize=12)
225
-
226
- # Add value labels to the end of each bar
227
- for i, v in enumerate(added_by_df["count"]):
228
- ax.text(v + 0.5, i, str(v), va="center")
229
-
230
- # Adjust x-axis to make room for labels
231
- plt.xlim(0, max(added_by_df["count"]) * 1.1)
232
-
233
- plt.tight_layout()
234
-
235
- # chart = Div(generate_chart(), id="chart")
236
- bar_chart = Div(generate_chart(), id="bar-chart")
237
- pie_chart = Div(generate_contributions_chart(), id="pie-chart")
238
-
239
  # add contributions
240
  article_df = article_ds.data.to_pandas()
241
  added_by_df = article_df.groupby("added_by").size().reset_index(name="count")
 
1
  import base64
2
+ import io
3
  import os
4
  import shutil
5
  from collections import defaultdict
6
  from datetime import date, datetime, timedelta
7
+ from functools import lru_cache
8
 
9
  import dotenv
10
  import matplotlib.pyplot as plt
 
17
  from huggingface_hub import login, whoami
18
 
19
  dotenv.load_dotenv()
20
+ login(token=os.environ.get("HF_TOKEN"))
21
 
 
 
 
 
 
 
 
 
22
 
23
+ PLACEHOLDER_IMAGE = (
24
+ "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
25
+ )
26
 
27
  # delete data folder
28
  if os.path.exists("data"):
 
31
  except OSError as e:
32
  print("Error: %s : %s" % ("data", e.strerror))
33
 
 
34
 
35
+ @lru_cache(maxsize=None)
36
+ def load_cached_dataset(repo_id, dataset_name, split):
37
+ return load_dataset(repo_id, dataset_name, split=split)
38
+
39
 
40
  hf_user = whoami(os.environ.get("HF_TOKEN"))["name"]
41
  HF_REPO_ID_TXT = f"{hf_user}/zotero-answer-ai-texts"
42
  HF_REPO_ID_IMG = f"{hf_user}/zotero-answer-ai-images"
43
 
44
+ abstract_ds = load_cached_dataset(HF_REPO_ID_TXT, "abstracts", "train")
45
+ article_ds = load_cached_dataset(HF_REPO_ID_TXT, "articles", "train")
46
+ image_ds = load_cached_dataset(HF_REPO_ID_IMG, "images_first_page", "train")
 
47
 
48
 
49
  def parse_date(date_string):
 
72
 
73
  def get_article_details(arxiv_id):
74
  article = arxiv2article.get(arxiv_id, {})
75
+ # abstract = arxiv2abstract.get(arxiv_id, {})
76
+ # image = arxiv2image.get(arxiv_id, {})
77
+ return article
78
+
79
+
80
+ # stats --
81
+ @matplotlib2fasthtml
82
+ def generate_chart():
83
+ end_date = max(weeks)
84
+ start_date = end_date - timedelta(weeks=11)
85
+
86
+ dates = []
87
+ counts = []
88
+ current_date = start_date
89
+ while current_date <= end_date:
90
+ count = len(week2articles[current_date])
91
+ date_str = current_date.strftime("%d-%B-%Y")
92
+ dates.append(date_str)
93
+ counts.append(count)
94
+ current_date += timedelta(weeks=1)
95
+
96
+ plt.figure(figsize=(12, 6))
97
+ sns.set_style("darkgrid")
98
+ # sns.set_palette("deep")
99
+
100
+ ax = sns.barplot(x=dates, y=counts)
101
+
102
+ plt.title("Papers per Week (Last 12 Weeks)", fontsize=16, fontweight="bold")
103
+ plt.xlabel("Week", fontsize=12)
104
+ plt.ylabel("Number of Papers", fontsize=12)
105
+
106
+ # Rotate and align the tick labels so they look better
107
+ plt.xticks(rotation=45, ha="right")
108
+
109
+ # Use a tight layout to prevent the labels from being cut off
110
+ plt.tight_layout()
111
+
112
+ # Add value labels on top of each bar
113
+ for i, v in enumerate(counts):
114
+ ax.text(i, v + 0.5, str(v), ha="center", va="bottom")
115
+
116
+ # Increase y-axis limit slightly to accommodate the value labels
117
+ plt.ylim(0, max(counts) * 1.1)
118
+
119
+
120
+ @matplotlib2fasthtml
121
+ def generate_contributions_chart():
122
+ article_df = article_ds.data.to_pandas()
123
+ added_by_df = article_df.groupby("added_by").size().reset_index(name="count")
124
+ added_by_df = added_by_df.sort_values("count", ascending=False) # Ascending for bottom-to-top order
125
+
126
+ plt.figure(figsize=(12, 8))
127
+ sns.set_style("darkgrid")
128
+ sns.set_palette("deep")
129
+
130
+ ax = sns.barplot(x="count", y="added_by", data=added_by_df)
131
+
132
+ plt.title("Upload Counts", fontsize=16, fontweight="bold")
133
+ plt.xlabel("Number of Articles Added", fontsize=12)
134
+ plt.ylabel("User", fontsize=12)
135
+
136
+ # Add value labels to the end of each bar
137
+ for i, v in enumerate(added_by_df["count"]):
138
+ ax.text(v + 0.5, i, str(v), va="center")
139
+
140
+ # Adjust x-axis to make room for labels
141
+ plt.xlim(0, max(added_by_df["count"]) * 1.1)
142
+
143
+ plt.tight_layout()
144
+
145
+
146
+ # chart = Div(generate_chart(), id="chart")
147
+ bar_chart = Div(generate_chart(), id="bar-chart")
148
+ pie_chart = Div(generate_contributions_chart(), id="pie-chart")
149
+
150
+ #### fasthtml app ####
151
+ style = Style("""
152
+ .grid { margin-bottom: 1rem; }
153
+ .card { display: flex; flex-direction: column; }
154
+ .card img { margin-bottom: 0.5rem; }
155
+ .card h5 { margin: 0; font-size: 0.9rem; line-height: 1.2; }
156
+ .card a { color: inherit; text-decoration: none; }
157
+ .card a:hover { text-decoration: underline; }
158
+ """)
159
+
160
+ app, rt = fast_app(html_style=(style,))
161
+
162
+
163
+ # Image ---
164
+ def optimize_image(pil_image, max_size=(500, 500), quality=85):
165
+ img_byte_arr = io.BytesIO()
166
+ pil_image.thumbnail(max_size)
167
+ pil_image.save(img_byte_arr, format="JPEG", quality=quality, optimize=True)
168
+ return f"data:image/jpeg;base64,{base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')}"
169
+
170
+
171
+ @lru_cache(maxsize=100)
172
+ def get_optimized_image(arxiv_id):
173
  image = arxiv2image.get(arxiv_id, {})
174
+ if image:
175
+ return optimize_image(image["image"])
176
+ return None
177
+
178
+
179
+ @rt("/image/{arxiv_id}")
180
+ def get(arxiv_id: str):
181
+ image_url = get_optimized_image(arxiv_id)
182
+ if image_url:
183
+ return Img(src=image_url, alt="Article image", style="max-width: 100%; height: auto; margin-bottom: 15px;")
184
+ return ""
185
 
186
 
187
  def generate_week_content(current_week):
 
210
  articles = week2articles[current_week]
211
  article_cards = []
212
  for arxiv_id in articles:
213
+ article = get_article_details(arxiv_id)
214
  article_title = article["contents"][0].get("paper_title", "article") if article["contents"] else "article"
215
 
216
  card_content = [
 
223
  )
224
  ]
225
 
226
+ # insert image
227
+ card_content.insert(
228
+ 0,
229
+ Img(
230
+ src=PLACEHOLDER_IMAGE, # image_url,
231
+ alt="Article image",
232
+ style="max-width: 100%; height: auto; margin-bottom: 15px;",
233
+ hx_get=f"/image/{arxiv_id}",
234
+ hx_trigger="revealed",
235
+ hx_swap="outerHTML",
236
+ ),
237
+ )
 
 
 
238
 
239
  article_cards.append(Card(*card_content, cls="mb-4"))
240
 
 
271
 
272
  @rt("/stats")
273
  async def get():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  # add contributions
275
  article_df = article_ds.data.to_pandas()
276
  added_by_df = article_df.groupby("added_by").size().reset_index(name="count")