Benjamin Bossan commited on
Commit
4c2b75c
·
1 Parent(s): 111044d

Add pdf processor using pypdf

Browse files

Also, enforce a max length to the processed output.

demo.py CHANGED
@@ -72,6 +72,9 @@ Input currently supports:
72
  - a URL to a webpage
73
  - a URL to a youtube video (the video will be transcribed)
74
  - a URL to an image (url ending in .png, .jpg, etc.; the image description will be used)
 
 
 
75
  """
76
 
77
 
 
72
  - a URL to a webpage
73
  - a URL to a youtube video (the video will be transcribed)
74
  - a URL to an image (url ending in .png, .jpg, etc.; the image description will be used)
75
+ - a URL to a PDF (url ending in .pdf, e.g. https://arxiv.org/pdf/2108.12409.pdf)
76
+
77
+ Long inputs will be truncated.
78
  """
79
 
80
 
requests.org CHANGED
@@ -62,6 +62,17 @@ curl -X 'POST' \
62
  }'
63
  #+end_src
64
 
 
 
 
 
 
 
 
 
 
 
 
65
  #+begin_src bash
66
  curl -X 'GET' \
67
  'http://localhost:8080/check_job_status/' \
 
62
  }'
63
  #+end_src
64
 
65
+ #+begin_src bash
66
+ curl -X 'POST' \
67
+ 'http://localhost:8080/submit/' \
68
+ -H 'accept: application/json' \
69
+ -H 'Content-Type: application/json' \
70
+ -d '{
71
+ "author": "ben",
72
+ "content": "https://arxiv.org/pdf/2108.12409.pdf"
73
+ }'
74
+ #+end_src
75
+
76
  #+begin_src bash
77
  curl -X 'GET' \
78
  'http://localhost:8080/check_job_status/' \
requirements.txt CHANGED
@@ -10,3 +10,4 @@ pillow
10
  gradio
11
  urllib3
12
  pytube
 
 
10
  gradio
11
  urllib3
12
  pytube
13
+ pypdf
src/gistillery/config.py CHANGED
@@ -2,14 +2,24 @@ import os
2
  from pathlib import Path
3
 
4
  from pydantic import BaseSettings
 
5
 
6
 
7
  class Config(BaseSettings):
8
  hf_hub_token: str = "missing"
9
  hf_agent: str = "https://api-inference.huggingface.co/models/bigcode/starcoder"
10
  db_file_name: Path = Path("sqlite-data.db")
11
- sampling_rate: int = 16_000 # audio transcription
12
- max_yt_length: int = 1800 # in minutes
 
 
 
 
 
 
 
 
 
13
 
14
  class Config:
15
  # load .env file by default, with provisio to use other .env files if set
 
2
  from pathlib import Path
3
 
4
  from pydantic import BaseSettings
5
+ from pydantic.types import PositiveInt
6
 
7
 
8
  class Config(BaseSettings):
9
  hf_hub_token: str = "missing"
10
  hf_agent: str = "https://api-inference.huggingface.co/models/bigcode/starcoder"
11
  db_file_name: Path = Path("sqlite-data.db")
12
+ processing_max_length: PositiveInt = 10000 # in characters
13
+ sampling_rate: PositiveInt = 16_000 # audio transcription
14
+ max_yt_length: PositiveInt = 1800 # in minutes
15
+ pdf_stop_words: list[str] = [
16
+ "ACKNOWLEDGMENTS",
17
+ "Acknowledgments",
18
+ "acknowledgments",
19
+ "REFERENCES",
20
+ "References",
21
+ "references",
22
+ ]
23
 
24
  class Config:
25
  # load .env file by default, with provisio to use other .env files if set
src/gistillery/errors.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ class ProcessingError(RuntimeError):
2
+ """Error raised when processing goes wrong"""
src/gistillery/preprocessing.py CHANGED
@@ -13,6 +13,7 @@ from transformers import AutoProcessor, WhisperForConditionalGeneration
13
 
14
  from gistillery.base import JobInput
15
  from gistillery.config import get_config
 
16
  from gistillery.media import download_yt_audio, load_audio
17
  from gistillery.tools import get_agent
18
 
@@ -32,13 +33,29 @@ def get_url(text: str) -> str | None:
32
 
33
 
34
  class Processor(abc.ABC):
 
 
 
 
35
  def get_name(self) -> str:
36
  return self.__class__.__name__
37
 
38
  def __call__(self, job: JobInput) -> str:
 
 
 
 
 
 
39
  _id = job.id
40
  logger.info(f"Processing {input} with {self.__class__.__name__} (id={_id[:8]})")
41
  result = self.process(job)
 
 
 
 
 
 
42
  logger.info(f"Finished processing input (id={_id[:8]})")
43
  return result
44
 
@@ -61,6 +78,7 @@ class RawTextProcessor(Processor):
61
 
62
  class DefaultUrlProcessor(Processor):
63
  def __init__(self) -> None:
 
64
  self.client = Client()
65
  self.url = Optional[str]
66
  self.template = "{url}\n\n{content}"
@@ -85,8 +103,48 @@ class DefaultUrlProcessor(Processor):
85
  return str(text)
86
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  class ImageUrlProcessor(Processor):
89
  def __init__(self) -> None:
 
90
  self.client = Client()
91
  self.url = Optional[str]
92
  self.template = "{url}\n\n{content}"
@@ -111,13 +169,15 @@ class ImageUrlProcessor(Processor):
111
  response = self.client.get(self.url)
112
  image = Image.open(io.BytesIO(response.content)).convert('RGB')
113
  caption = get_agent().run("Caption the following image", image=image)
114
- return str(caption)
 
115
 
116
 
117
  class YoutubeUrlProcessor(Processor):
118
  """Download yt audio, transcribe with whisper"""
119
 
120
  def __init__(self) -> None:
 
121
  self.client = Client()
122
  self.url = Optional[str]
123
  self.template = "{url}\n\n{content}"
 
13
 
14
  from gistillery.base import JobInput
15
  from gistillery.config import get_config
16
+ from gistillery.errors import ProcessingError
17
  from gistillery.media import download_yt_audio, load_audio
18
  from gistillery.tools import get_agent
19
 
 
33
 
34
 
35
  class Processor(abc.ABC):
36
+ def __init__(self) -> None:
37
+ self.max_length = get_config().processing_max_length
38
+ self._super_init_called = True
39
+
40
  def get_name(self) -> str:
41
  return self.__class__.__name__
42
 
43
  def __call__(self, job: JobInput) -> str:
44
+ if not self._super_init_called:
45
+ raise RuntimeError(
46
+ "super().__init__() was not called with class "
47
+ f"{self.__class__.__name__}"
48
+ )
49
+
50
  _id = job.id
51
  logger.info(f"Processing {input} with {self.__class__.__name__} (id={_id[:8]})")
52
  result = self.process(job)
53
+ if len(result) > self.max_length:
54
+ logger.warning(
55
+ f"Length of result ({len(result)}) exceeds max_length "
56
+ f"({self.max_length}), truncating"
57
+ )
58
+ result = result[: self.max_length]
59
  logger.info(f"Finished processing input (id={_id[:8]})")
60
  return result
61
 
 
78
 
79
  class DefaultUrlProcessor(Processor):
80
  def __init__(self) -> None:
81
+ super().__init__()
82
  self.client = Client()
83
  self.url = Optional[str]
84
  self.template = "{url}\n\n{content}"
 
103
  return str(text)
104
 
105
 
106
+ class PdfUrlProcessor(Processor):
107
+ def __init__(self) -> None:
108
+ super().__init__()
109
+ self.client = Client()
110
+ self.url = Optional[str]
111
+ self.template = "{url}\n\n{content}"
112
+ self.stop_words = get_config().pdf_stop_words
113
+
114
+ def match(self, input: JobInput) -> bool:
115
+ url = get_url(input.content.strip())
116
+ if url is None:
117
+ return False
118
+
119
+ suffix = url.rsplit(".", 1)[-1].lower()
120
+ if suffix != "pdf":
121
+ return False
122
+
123
+ self.url = url
124
+ return True
125
+
126
+ def process(self, input: JobInput) -> str:
127
+ if not isinstance(self.url, str):
128
+ raise TypeError("self.url must be a string")
129
+
130
+ response = self.client.get(self.url)
131
+ import pypdf
132
+
133
+ pdf = pypdf.PdfReader(io.BytesIO(response.content))
134
+ results = []
135
+ for page in pdf.pages:
136
+ results.append(page.extract_text())
137
+ if any(word in results[-1] for word in self.stop_words):
138
+ break
139
+ text = "\n".join(results).strip()
140
+ if not text:
141
+ raise ProcessingError("No text could be extracted from PDF")
142
+ return self.template.format(url=self.url, content=text)
143
+
144
+
145
  class ImageUrlProcessor(Processor):
146
  def __init__(self) -> None:
147
+ super().__init__()
148
  self.client = Client()
149
  self.url = Optional[str]
150
  self.template = "{url}\n\n{content}"
 
169
  response = self.client.get(self.url)
170
  image = Image.open(io.BytesIO(response.content)).convert('RGB')
171
  caption = get_agent().run("Caption the following image", image=image)
172
+ text = str(caption)
173
+ return self.template.format(url=self.url, content=text)
174
 
175
 
176
  class YoutubeUrlProcessor(Processor):
177
  """Download yt audio, transcribe with whisper"""
178
 
179
  def __init__(self) -> None:
180
+ super().__init__()
181
  self.client = Client()
182
  self.url = Optional[str]
183
  self.template = "{url}\n\n{content}"
src/gistillery/registry.py CHANGED
@@ -2,6 +2,7 @@ from gistillery.base import JobInput
2
  from gistillery.preprocessing import (
3
  DefaultUrlProcessor,
4
  ImageUrlProcessor,
 
5
  Processor,
6
  RawTextProcessor,
7
  YoutubeUrlProcessor,
@@ -59,6 +60,7 @@ def get_tool_registry() -> ToolRegistry:
59
 
60
  _registry = ToolRegistry()
61
  _registry.register_processor(YoutubeUrlProcessor())
 
62
  _registry.register_processor(ImageUrlProcessor())
63
  _registry.register_processor(DefaultUrlProcessor())
64
  _registry.register_processor(RawTextProcessor())
 
2
  from gistillery.preprocessing import (
3
  DefaultUrlProcessor,
4
  ImageUrlProcessor,
5
+ PdfUrlProcessor,
6
  Processor,
7
  RawTextProcessor,
8
  YoutubeUrlProcessor,
 
60
 
61
  _registry = ToolRegistry()
62
  _registry.register_processor(YoutubeUrlProcessor())
63
+ _registry.register_processor(PdfUrlProcessor())
64
  _registry.register_processor(ImageUrlProcessor())
65
  _registry.register_processor(DefaultUrlProcessor())
66
  _registry.register_processor(RawTextProcessor())