Benjamin Bossan commited on
Commit
2e4eb94
·
1 Parent(s): 2cbbc23

Possibility to filter by multiple tags

Browse files

Also, demo now allows to select tags to filter on.

Files changed (4) hide show
  1. README.md +2 -2
  2. demo.py +29 -4
  3. src/gistillery/webservice.py +24 -10
  4. tests/test_app.py +54 -2
README.md CHANGED
@@ -94,8 +94,8 @@ python -m pytest tests/
94
 
95
  ```sh
96
  mypy src/
97
- black src/ && black tests/
98
- ruff src/ && ruff tests/
99
  ```
100
 
101
  ## TODOs
 
94
 
95
  ```sh
96
  mypy src/
97
+ black src/ && black tests/ && black demo.py
98
+ ruff src/ && ruff tests/ && ruff demo.py
99
  ```
100
 
101
  ## TODOs
demo.py CHANGED
@@ -3,6 +3,12 @@ import gradio as gr
3
 
4
 
5
  client = httpx.Client()
 
 
 
 
 
 
6
 
7
 
8
  def submit(inputs):
@@ -17,8 +23,12 @@ def check_status():
17
  return response.json()
18
 
19
 
20
- def get_results():
21
- response = httpx.get("http://localhost:8080/recent/")
 
 
 
 
22
  entries = response.json()
23
  texts: list[str] = []
24
  for i, entry in enumerate(entries, start=1):
@@ -31,21 +41,36 @@ def get_results():
31
  return "\n\n---\n\n".join(texts)
32
 
33
 
 
 
 
 
 
 
 
 
34
  def get_demo():
35
  with gr.Blocks() as demo:
36
  # submit new input
37
- inputs = gr.Textbox(lines=3, label="Input (text, URL)")
 
38
  btn_submit = gr.Button("Submit")
39
 
40
  # check job status
41
  gr.HTML(value=check_status, label="Status", every=3)
42
 
 
 
 
 
 
 
43
  # display output
44
  btn_output = gr.Button("Show results")
45
  output = gr.Markdown()
46
 
47
  btn_submit.click(submit, inputs=inputs)
48
- btn_output.click(get_results, outputs=[output])
49
 
50
  return demo
51
 
 
3
 
4
 
5
  client = httpx.Client()
6
+ # TODO: update the tags somehow; re-fetching inside of check_status doesn't work
7
+ tag_counts = {
8
+ key.strip("#"): val
9
+ for key, val in client.get("http://localhost:8080/tag_counts/").json().items()
10
+ if key.strip("#")
11
+ }
12
 
13
 
14
  def submit(inputs):
 
23
  return response.json()
24
 
25
 
26
+ def get_results(inputs: list[str]):
27
+ if not inputs:
28
+ response = httpx.get("http://localhost:8080/recent/")
29
+ else:
30
+ tags = [tag.split(" ", 1)[0] for tag in inputs]
31
+ response = httpx.get("http://localhost:8080/recent/" + ",".join(tags))
32
  entries = response.json()
33
  texts: list[str] = []
34
  for i, entry in enumerate(entries, start=1):
 
41
  return "\n\n---\n\n".join(texts)
42
 
43
 
44
+ INPUT_DESCRIPTION = """Input currently supports:
45
+ - plain text
46
+ - a URL to a webpage
47
+ - a URL to a youtube video (the video will be transcribed)
48
+ - a URL to an image (the image description will be used)
49
+ """
50
+
51
+
52
  def get_demo():
53
  with gr.Blocks() as demo:
54
  # submit new input
55
+ gr.Markdown(INPUT_DESCRIPTION)
56
+ inputs = gr.Textbox(lines=3, label="Input")
57
  btn_submit = gr.Button("Submit")
58
 
59
  # check job status
60
  gr.HTML(value=check_status, label="Status", every=3)
61
 
62
+ # check box of tags to filter on
63
+ tag_choices = sorted(f"{key} ({val})" for key, val in tag_counts.items())
64
+ tags = gr.CheckboxGroup(
65
+ tag_choices, label="Filter on tags (no selection = all)"
66
+ )
67
+
68
  # display output
69
  btn_output = gr.Button("Show results")
70
  output = gr.Markdown()
71
 
72
  btn_submit.click(submit, inputs=inputs)
73
+ btn_output.click(get_results, inputs=[tags], outputs=[output])
74
 
75
  return demo
76
 
src/gistillery/webservice.py CHANGED
@@ -87,7 +87,7 @@ def recent() -> list[EntriesResult]:
87
  JOIN tags t ON e.id = t.entry_id
88
  GROUP BY e.id
89
  ORDER BY e.created_at DESC
90
- LIMIT 10
91
  """)
92
  results = cursor.fetchall()
93
 
@@ -102,10 +102,10 @@ def recent() -> list[EntriesResult]:
102
 
103
  @app.get("/recent/{tag}")
104
  def recent_tag(tag: str) -> list[EntriesResult]:
105
- if not tag.startswith("#"):
106
- tag = "#" + tag
107
 
108
- # same as recent, but filter by tag
109
  with get_db_cursor() as cursor:
110
  cursor.execute(
111
  """
@@ -115,25 +115,39 @@ def recent_tag(tag: str) -> list[EntriesResult]:
115
  JOIN summaries s ON e.id = s.entry_id
116
  JOIN tags t ON e.id = t.entry_id
117
  WHERE e.id IN (
118
- SELECT entry_id FROM tags WHERE tag = ?
119
  )
120
  GROUP BY e.id
121
  ORDER BY e.created_at DESC
122
- LIMIT 10
123
- """,
124
- (tag,),
125
  )
126
  results = cursor.fetchall()
127
 
128
  entries = []
129
- for _id, author, summary, tags, date in results:
130
  entry = EntriesResult(
131
- id=_id, author=author, summary=summary, tags=tags.split(","), date=date
132
  )
133
  entries.append(entry)
134
  return entries
135
 
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  @app.get("/clear/")
138
  def clear() -> str:
139
  # clear all tables
 
87
  JOIN tags t ON e.id = t.entry_id
88
  GROUP BY e.id
89
  ORDER BY e.created_at DESC
90
+ LIMIT 100
91
  """)
92
  results = cursor.fetchall()
93
 
 
102
 
103
  @app.get("/recent/{tag}")
104
  def recent_tag(tag: str) -> list[EntriesResult]:
105
+ tags = tag.split(",")
106
+ tags = ["#" + tag for tag in tags if not tag.startswith("#")]
107
 
108
+ # same as recent, but filter by tags, where at least one tag matches
109
  with get_db_cursor() as cursor:
110
  cursor.execute(
111
  """
 
115
  JOIN summaries s ON e.id = s.entry_id
116
  JOIN tags t ON e.id = t.entry_id
117
  WHERE e.id IN (
118
+ SELECT entry_id FROM tags WHERE tag IN ({})
119
  )
120
  GROUP BY e.id
121
  ORDER BY e.created_at DESC
122
+ LIMIT 100
123
+ """.format(",".join("?" * len(tags))),
124
+ tags,
125
  )
126
  results = cursor.fetchall()
127
 
128
  entries = []
129
+ for _id, author, summary, tag, date in results:
130
  entry = EntriesResult(
131
+ id=_id, author=author, summary=summary, tags=tag.split(","), date=date
132
  )
133
  entries.append(entry)
134
  return entries
135
 
136
 
137
+ @app.get("/tag_counts/")
138
+ def tag_counts() -> dict[str, int]:
139
+ with get_db_cursor() as cursor:
140
+ cursor.execute("""
141
+ SELECT tag, COUNT(*) count
142
+ FROM tags
143
+ GROUP BY tag
144
+ ORDER BY count DESC
145
+ """)
146
+ results = cursor.fetchall()
147
+
148
+ return {tag: count for tag, count in results}
149
+
150
+
151
  @app.get("/clear/")
152
  def clear() -> str:
153
  # clear all tables
tests/test_app.py CHANGED
@@ -90,11 +90,10 @@ class TestWebservice:
90
  resp = client.get("/recent")
91
  assert resp.json() == []
92
 
93
- def test_recent_tag_empty(self, client):
94
  resp = client.get("/recent/general")
95
  assert resp.json() == []
96
 
97
- def test_submitted_job_status_pending(self, client, monkeypatch):
98
  # monkeypatch uuid4 to return a known value
99
  job_id = "abc1234"
100
  monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
@@ -238,6 +237,59 @@ class TestWebservice:
238
  assert resp0["summary"] == "this would"
239
  assert resp0["tags"] == sorted(["#this", "#would", "#be"])
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  def test_clear(self, client, cursor, registry):
242
  client.post("/submit", json={"author": "ben", "content": "this is a test"})
243
  self.process_jobs(registry)
 
90
  resp = client.get("/recent")
91
  assert resp.json() == []
92
 
93
+ def test_recent_tag_empty(self, client, monkeypatch):
94
  resp = client.get("/recent/general")
95
  assert resp.json() == []
96
 
 
97
  # monkeypatch uuid4 to return a known value
98
  job_id = "abc1234"
99
  monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
 
237
  assert resp0["summary"] == "this would"
238
  assert resp0["tags"] == sorted(["#this", "#would", "#be"])
239
 
240
+ def test_recent_multiple_entries(self, client, registry):
241
+ # submit 2 entries
242
+ client.post(
243
+ "/submit", json={"author": "maxi", "content": "aardvark ant antelope"}
244
+ )
245
+ client.post(
246
+ "/submit",
247
+ json={"author": "mini", "content": "bat bear bee"},
248
+ )
249
+ client.post(
250
+ "/submit",
251
+ json={"author": "mini", "content": "camel canary cat"},
252
+ )
253
+ self.process_jobs(registry)
254
+
255
+ # the "ant" tag is in only one entry
256
+ resp = client.get("/recent/ant").json()
257
+ assert len(resp) == 1
258
+
259
+ # "ant" and "bee" are in two entries
260
+ resp = client.get("/recent/ant,bee").json()
261
+ assert len(resp) == 2
262
+
263
+ # "ant" and "bee" and "cat" are in three entries
264
+ resp = client.get("/recent/cat,ant,bee").json()
265
+ assert len(resp) == 3
266
+
267
+ def test_tag_count(self, client, registry):
268
+ # submit 2 entries
269
+ client.post(
270
+ "/submit", json={"author": "ben", "content": "aardvark ant antelope"}
271
+ )
272
+ client.post(
273
+ "/submit",
274
+ json={"author": "ben", "content": "aardvark ant bat"},
275
+ )
276
+ client.post(
277
+ "/submit",
278
+ json={"author": "ben", "content": "aardvark camel canary"},
279
+ )
280
+ self.process_jobs(registry)
281
+
282
+ resp = client.get("/tag_counts").json()
283
+ expected = {
284
+ "#aardvark": 3,
285
+ "#ant": 2,
286
+ "#antelope": 1,
287
+ "#bat": 1,
288
+ "#camel": 1,
289
+ "#canary": 1,
290
+ }
291
+ assert resp == expected
292
+
293
  def test_clear(self, client, cursor, registry):
294
  client.post("/submit", json={"author": "ben", "content": "this is a test"})
295
  self.process_jobs(registry)