Jonathan Wang committed
Commit 89cbc4d · 0 Parent(s)

initial commit

Files changed (26)
  1. .gitattributes +7 -0
  2. .gitignore +171 -0
  3. .streamlit/config.toml +2 -0
  4. .vscode/launch.json +3 -0
  5. LICENSE +661 -0
  6. README.md +53 -0
  7. agent.py +92 -0
  8. app.py +471 -0
  9. citation.py +245 -0
  10. engine.py +126 -0
  11. full_doc.py +336 -0
  12. keywords.py +110 -0
  13. merger.py +174 -0
  14. metadata_adder.py +280 -0
  15. models.py +785 -0
  16. obs_logging.py +380 -0
  17. packages.txt +4 -0
  18. parsers.py +106 -0
  19. pdf_reader.py +528 -0
  20. pdf_reader_utils.py +592 -0
  21. prompts.py +86 -0
  22. pyproject.toml +53 -0
  23. requirements.txt +34 -0
  24. retriever.py +280 -0
  25. storage.py +120 -0
  26. summary.py +246 -0
.gitattributes ADDED
@@ -0,0 +1,7 @@
+ nltk_data/taggers/averaged_perceptron_tagger/averaged_perceptron_tagger.pickle filter=lfs diff=lfs merge=lfs -text
+ nltk_data/tokenizers/punkt/english.pickle filter=lfs diff=lfs merge=lfs -text
+ nltk_data/tokenizers/punkt/PY3/english.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.tab filter=lfs diff=lfs merge=lfs -text
+ *.json filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,171 @@
+ ##### LOCAL PROJECT FILES #####
+ data/
+ refs/
+ figures/
+ config.py
+ .streamlit/secrets.toml
+
+ ###############################
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
.streamlit/config.toml ADDED
@@ -0,0 +1,2 @@
+ [browser]
+ gatherUsageStats = false
.vscode/launch.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f435c38bfb7c91633a094d3ca2f8224839fb2151158536bda1ca0de4b395b426
+ size 624
LICENSE ADDED
@@ -0,0 +1,661 @@
1
+ GNU AFFERO GENERAL PUBLIC LICENSE
2
+ Version 3, 19 November 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU Affero General Public License is a free, copyleft license for
11
+ software and other kinds of works, specifically designed to ensure
12
+ cooperation with the community in the case of network server software.
13
+
14
+ The licenses for most software and other practical works are designed
15
+ to take away your freedom to share and change the works. By contrast,
16
+ our General Public Licenses are intended to guarantee your freedom to
17
+ share and change all versions of a program--to make sure it remains free
18
+ software for all its users.
19
+
20
+ When we speak of free software, we are referring to freedom, not
21
+ price. Our General Public Licenses are designed to make sure that you
22
+ have the freedom to distribute copies of free software (and charge for
23
+ them if you wish), that you receive source code or can get it if you
24
+ want it, that you can change the software or use pieces of it in new
25
+ free programs, and that you know you can do these things.
26
+
27
+ Developers that use our General Public Licenses protect your rights
28
+ with two steps: (1) assert copyright on the software, and (2) offer
29
+ you this License which gives you legal permission to copy, distribute
30
+ and/or modify the software.
31
+
32
+ A secondary benefit of defending all users' freedom is that
33
+ improvements made in alternate versions of the program, if they
34
+ receive widespread use, become available for other developers to
35
+ incorporate. Many developers of free software are heartened and
36
+ encouraged by the resulting cooperation. However, in the case of
37
+ software used on network servers, this result may fail to come about.
38
+ The GNU General Public License permits making a modified version and
39
+ letting the public access it on a server without ever releasing its
40
+ source code to the public.
41
+
42
+ The GNU Affero General Public License is designed specifically to
43
+ ensure that, in such cases, the modified source code becomes available
44
+ to the community. It requires the operator of a network server to
45
+ provide the source code of the modified version running there to the
46
+ users of that server. Therefore, public use of a modified version, on
47
+ a publicly accessible server, gives the public access to the source
48
+ code of the modified version.
49
+
50
+ An older license, called the Affero General Public License and
51
+ published by Affero, was designed to accomplish similar goals. This is
52
+ a different license, not a version of the Affero GPL, but Affero has
53
+ released a new version of the Affero GPL which permits relicensing under
54
+ this license.
55
+
56
+ The precise terms and conditions for copying, distribution and
57
+ modification follow.
58
+
59
+ TERMS AND CONDITIONS
60
+
61
+ 0. Definitions.
62
+
63
+ "This License" refers to version 3 of the GNU Affero General Public License.
64
+
65
+ "Copyright" also means copyright-like laws that apply to other kinds of
66
+ works, such as semiconductor masks.
67
+
68
+ "The Program" refers to any copyrightable work licensed under this
69
+ License. Each licensee is addressed as "you". "Licensees" and
70
+ "recipients" may be individuals or organizations.
71
+
72
+ To "modify" a work means to copy from or adapt all or part of the work
73
+ in a fashion requiring copyright permission, other than the making of an
74
+ exact copy. The resulting work is called a "modified version" of the
75
+ earlier work or a work "based on" the earlier work.
76
+
77
+ A "covered work" means either the unmodified Program or a work based
78
+ on the Program.
79
+
80
+ To "propagate" a work means to do anything with it that, without
81
+ permission, would make you directly or secondarily liable for
82
+ infringement under applicable copyright law, except executing it on a
83
+ computer or modifying a private copy. Propagation includes copying,
84
+ distribution (with or without modification), making available to the
85
+ public, and in some countries other activities as well.
86
+
87
+ To "convey" a work means any kind of propagation that enables other
88
+ parties to make or receive copies. Mere interaction with a user through
89
+ a computer network, with no transfer of a copy, is not conveying.
90
+
91
+ An interactive user interface displays "Appropriate Legal Notices"
92
+ to the extent that it includes a convenient and prominently visible
93
+ feature that (1) displays an appropriate copyright notice, and (2)
94
+ tells the user that there is no warranty for the work (except to the
95
+ extent that warranties are provided), that licensees may convey the
96
+ work under this License, and how to view a copy of this License. If
97
+ the interface presents a list of user commands or options, such as a
98
+ menu, a prominent item in the list meets this criterion.
99
+
100
+ 1. Source Code.
101
+
102
+ The "source code" for a work means the preferred form of the work
103
+ for making modifications to it. "Object code" means any non-source
104
+ form of a work.
105
+
106
+ A "Standard Interface" means an interface that either is an official
107
+ standard defined by a recognized standards body, or, in the case of
108
+ interfaces specified for a particular programming language, one that
109
+ is widely used among developers working in that language.
110
+
111
+ The "System Libraries" of an executable work include anything, other
112
+ than the work as a whole, that (a) is included in the normal form of
113
+ packaging a Major Component, but which is not part of that Major
114
+ Component, and (b) serves only to enable use of the work with that
115
+ Major Component, or to implement a Standard Interface for which an
116
+ implementation is available to the public in source code form. A
117
+ "Major Component", in this context, means a major essential component
118
+ (kernel, window system, and so on) of the specific operating system
119
+ (if any) on which the executable work runs, or a compiler used to
120
+ produce the work, or an object code interpreter used to run it.
121
+
122
+ The "Corresponding Source" for a work in object code form means all
123
+ the source code needed to generate, install, and (for an executable
124
+ work) run the object code and to modify the work, including scripts to
125
+ control those activities. However, it does not include the work's
126
+ System Libraries, or general-purpose tools or generally available free
127
+ programs which are used unmodified in performing those activities but
128
+ which are not part of the work. For example, Corresponding Source
129
+ includes interface definition files associated with source files for
130
+ the work, and the source code for shared libraries and dynamically
131
+ linked subprograms that the work is specifically designed to require,
132
+ such as by intimate data communication or control flow between those
133
+ subprograms and other parts of the work.
134
+
135
+ The Corresponding Source need not include anything that users
136
+ can regenerate automatically from other parts of the Corresponding
137
+ Source.
138
+
139
+ The Corresponding Source for a work in source code form is that
140
+ same work.
141
+
142
+ 2. Basic Permissions.
143
+
144
+ All rights granted under this License are granted for the term of
145
+ copyright on the Program, and are irrevocable provided the stated
146
+ conditions are met. This License explicitly affirms your unlimited
147
+ permission to run the unmodified Program. The output from running a
148
+ covered work is covered by this License only if the output, given its
149
+ content, constitutes a covered work. This License acknowledges your
150
+ rights of fair use or other equivalent, as provided by copyright law.
151
+
152
+ You may make, run and propagate covered works that you do not
153
+ convey, without conditions so long as your license otherwise remains
154
+ in force. You may convey covered works to others for the sole purpose
155
+ of having them make modifications exclusively for you, or provide you
156
+ with facilities for running those works, provided that you comply with
157
+ the terms of this License in conveying all material for which you do
158
+ not control copyright. Those thus making or running the covered works
159
+ for you must do so exclusively on your behalf, under your direction
160
+ and control, on terms that prohibit them from making any copies of
161
+ your copyrighted material outside their relationship with you.
162
+
163
+ Conveying under any other circumstances is permitted solely under
164
+ the conditions stated below. Sublicensing is not allowed; section 10
165
+ makes it unnecessary.
166
+
167
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
168
+
169
+ No covered work shall be deemed part of an effective technological
170
+ measure under any applicable law fulfilling obligations under article
171
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
172
+ similar laws prohibiting or restricting circumvention of such
173
+ measures.
174
+
175
+ When you convey a covered work, you waive any legal power to forbid
176
+ circumvention of technological measures to the extent such circumvention
177
+ is effected by exercising rights under this License with respect to
178
+ the covered work, and you disclaim any intention to limit operation or
179
+ modification of the work as a means of enforcing, against the work's
180
+ users, your or third parties' legal rights to forbid circumvention of
181
+ technological measures.
182
+
183
+ 4. Conveying Verbatim Copies.
184
+
185
+ You may convey verbatim copies of the Program's source code as you
186
+ receive it, in any medium, provided that you conspicuously and
187
+ appropriately publish on each copy an appropriate copyright notice;
188
+ keep intact all notices stating that this License and any
189
+ non-permissive terms added in accord with section 7 apply to the code;
190
+ keep intact all notices of the absence of any warranty; and give all
191
+ recipients a copy of this License along with the Program.
192
+
193
+ You may charge any price or no price for each copy that you convey,
194
+ and you may offer support or warranty protection for a fee.
195
+
196
+ 5. Conveying Modified Source Versions.
197
+
198
+ You may convey a work based on the Program, or the modifications to
199
+ produce it from the Program, in the form of source code under the
200
+ terms of section 4, provided that you also meet all of these conditions:
201
+
202
+ a) The work must carry prominent notices stating that you modified
203
+ it, and giving a relevant date.
204
+
205
+ b) The work must carry prominent notices stating that it is
206
+ released under this License and any conditions added under section
207
+ 7. This requirement modifies the requirement in section 4 to
208
+ "keep intact all notices".
209
+
210
+ c) You must license the entire work, as a whole, under this
211
+ License to anyone who comes into possession of a copy. This
212
+ License will therefore apply, along with any applicable section 7
213
+ additional terms, to the whole of the work, and all its parts,
214
+ regardless of how they are packaged. This License gives no
215
+ permission to license the work in any other way, but it does not
216
+ invalidate such permission if you have separately received it.
217
+
218
+ d) If the work has interactive user interfaces, each must display
219
+ Appropriate Legal Notices; however, if the Program has interactive
220
+ interfaces that do not display Appropriate Legal Notices, your
221
+ work need not make them do so.
222
+
223
+ A compilation of a covered work with other separate and independent
224
+ works, which are not by their nature extensions of the covered work,
225
+ and which are not combined with it such as to form a larger program,
226
+ in or on a volume of a storage or distribution medium, is called an
227
+ "aggregate" if the compilation and its resulting copyright are not
228
+ used to limit the access or legal rights of the compilation's users
229
+ beyond what the individual works permit. Inclusion of a covered work
230
+ in an aggregate does not cause this License to apply to the other
231
+ parts of the aggregate.
232
+
233
+ 6. Conveying Non-Source Forms.
234
+
235
+ You may convey a covered work in object code form under the terms
236
+ of sections 4 and 5, provided that you also convey the
237
+ machine-readable Corresponding Source under the terms of this License,
238
+ in one of these ways:
239
+
240
+ a) Convey the object code in, or embodied in, a physical product
241
+ (including a physical distribution medium), accompanied by the
242
+ Corresponding Source fixed on a durable physical medium
243
+ customarily used for software interchange.
244
+
245
+ b) Convey the object code in, or embodied in, a physical product
246
+ (including a physical distribution medium), accompanied by a
247
+ written offer, valid for at least three years and valid for as
248
+ long as you offer spare parts or customer support for that product
249
+ model, to give anyone who possesses the object code either (1) a
250
+ copy of the Corresponding Source for all the software in the
251
+ product that is covered by this License, on a durable physical
252
+ medium customarily used for software interchange, for a price no
253
+ more than your reasonable cost of physically performing this
254
+ conveying of source, or (2) access to copy the
255
+ Corresponding Source from a network server at no charge.
256
+
257
+ c) Convey individual copies of the object code with a copy of the
258
+ written offer to provide the Corresponding Source. This
259
+ alternative is allowed only occasionally and noncommercially, and
260
+ only if you received the object code with such an offer, in accord
261
+ with subsection 6b.
262
+
263
+ d) Convey the object code by offering access from a designated
264
+ place (gratis or for a charge), and offer equivalent access to the
265
+ Corresponding Source in the same way through the same place at no
266
+ further charge. You need not require recipients to copy the
267
+ Corresponding Source along with the object code. If the place to
268
+ copy the object code is a network server, the Corresponding Source
269
+ may be on a different server (operated by you or a third party)
270
+ that supports equivalent copying facilities, provided you maintain
271
+ clear directions next to the object code saying where to find the
272
+ Corresponding Source. Regardless of what server hosts the
273
+ Corresponding Source, you remain obligated to ensure that it is
274
+ available for as long as needed to satisfy these requirements.
275
+
276
+ e) Convey the object code using peer-to-peer transmission, provided
277
+ you inform other peers where the object code and Corresponding
278
+ Source of the work are being offered to the general public at no
279
+ charge under subsection 6d.
280
+
281
+ A separable portion of the object code, whose source code is excluded
282
+ from the Corresponding Source as a System Library, need not be
283
+ included in conveying the object code work.
284
+
285
+ A "User Product" is either (1) a "consumer product", which means any
286
+ tangible personal property which is normally used for personal, family,
287
+ or household purposes, or (2) anything designed or sold for incorporation
288
+ into a dwelling. In determining whether a product is a consumer product,
289
+ doubtful cases shall be resolved in favor of coverage. For a particular
290
+ product received by a particular user, "normally used" refers to a
291
+ typical or common use of that class of product, regardless of the status
292
+ of the particular user or of the way in which the particular user
293
+ actually uses, or expects or is expected to use, the product. A product
294
+ is a consumer product regardless of whether the product has substantial
295
+ commercial, industrial or non-consumer uses, unless such uses represent
296
+ the only significant mode of use of the product.
297
+
298
+ "Installation Information" for a User Product means any methods,
299
+ procedures, authorization keys, or other information required to install
300
+ and execute modified versions of a covered work in that User Product from
301
+ a modified version of its Corresponding Source. The information must
302
+ suffice to ensure that the continued functioning of the modified object
303
+ code is in no case prevented or interfered with solely because
304
+ modification has been made.
305
+
306
+ If you convey an object code work under this section in, or with, or
307
+ specifically for use in, a User Product, and the conveying occurs as
308
+ part of a transaction in which the right of possession and use of the
309
+ User Product is transferred to the recipient in perpetuity or for a
310
+ fixed term (regardless of how the transaction is characterized), the
311
+ Corresponding Source conveyed under this section must be accompanied
312
+ by the Installation Information. But this requirement does not apply
313
+ if neither you nor any third party retains the ability to install
314
+ modified object code on the User Product (for example, the work has
315
+ been installed in ROM).
316
+
317
+ The requirement to provide Installation Information does not include a
318
+ requirement to continue to provide support service, warranty, or updates
319
+ for a work that has been modified or installed by the recipient, or for
320
+ the User Product in which it has been modified or installed. Access to a
321
+ network may be denied when the modification itself materially and
322
+ adversely affects the operation of the network or violates the rules and
323
+ protocols for communication across the network.
324
+
325
+ Corresponding Source conveyed, and Installation Information provided,
326
+ in accord with this section must be in a format that is publicly
327
+ documented (and with an implementation available to the public in
328
+ source code form), and must require no special password or key for
329
+ unpacking, reading or copying.
330
+
331
+ 7. Additional Terms.
332
+
333
+ "Additional permissions" are terms that supplement the terms of this
334
+ License by making exceptions from one or more of its conditions.
335
+ Additional permissions that are applicable to the entire Program shall
336
+ be treated as though they were included in this License, to the extent
337
+ that they are valid under applicable law. If additional permissions
338
+ apply only to part of the Program, that part may be used separately
339
+ under those permissions, but the entire Program remains governed by
340
+ this License without regard to the additional permissions.
341
+
342
+ When you convey a copy of a covered work, you may at your option
343
+ remove any additional permissions from that copy, or from any part of
344
+ it. (Additional permissions may be written to require their own
345
+ removal in certain cases when you modify the work.) You may place
346
+ additional permissions on material, added by you to a covered work,
347
+ for which you have or can give appropriate copyright permission.
348
+
349
+ Notwithstanding any other provision of this License, for material you
350
+ add to a covered work, you may (if authorized by the copyright holders of
351
+ that material) supplement the terms of this License with terms:
352
+
353
+ a) Disclaiming warranty or limiting liability differently from the
354
+ terms of sections 15 and 16 of this License; or
355
+
356
+ b) Requiring preservation of specified reasonable legal notices or
357
+ author attributions in that material or in the Appropriate Legal
358
+ Notices displayed by works containing it; or
359
+
360
+ c) Prohibiting misrepresentation of the origin of that material, or
361
+ requiring that modified versions of such material be marked in
362
+ reasonable ways as different from the original version; or
363
+
364
+ d) Limiting the use for publicity purposes of names of licensors or
365
+ authors of the material; or
366
+
367
+ e) Declining to grant rights under trademark law for use of some
368
+ trade names, trademarks, or service marks; or
369
+
370
+ f) Requiring indemnification of licensors and authors of that
371
+ material by anyone who conveys the material (or modified versions of
372
+ it) with contractual assumptions of liability to the recipient, for
373
+ any liability that these contractual assumptions directly impose on
374
+ those licensors and authors.
375
+
376
+ All other non-permissive additional terms are considered "further
377
+ restrictions" within the meaning of section 10. If the Program as you
378
+ received it, or any part of it, contains a notice stating that it is
379
+ governed by this License along with a term that is a further
380
+ restriction, you may remove that term. If a license document contains
381
+ a further restriction but permits relicensing or conveying under this
382
+ License, you may add to a covered work material governed by the terms
383
+ of that license document, provided that the further restriction does
384
+ not survive such relicensing or conveying.
385
+
386
+ If you add terms to a covered work in accord with this section, you
387
+ must place, in the relevant source files, a statement of the
388
+ additional terms that apply to those files, or a notice indicating
389
+ where to find the applicable terms.
390
+
391
+ Additional terms, permissive or non-permissive, may be stated in the
392
+ form of a separately written license, or stated as exceptions;
393
+ the above requirements apply either way.
394
+
395
+ 8. Termination.
396
+
397
+ You may not propagate or modify a covered work except as expressly
398
+ provided under this License. Any attempt otherwise to propagate or
399
+ modify it is void, and will automatically terminate your rights under
400
+ this License (including any patent licenses granted under the third
401
+ paragraph of section 11).
402
+
403
+ However, if you cease all violation of this License, then your
404
+ license from a particular copyright holder is reinstated (a)
405
+ provisionally, unless and until the copyright holder explicitly and
406
+ finally terminates your license, and (b) permanently, if the copyright
407
+ holder fails to notify you of the violation by some reasonable means
408
+ prior to 60 days after the cessation.
409
+
410
+ Moreover, your license from a particular copyright holder is
411
+ reinstated permanently if the copyright holder notifies you of the
412
+ violation by some reasonable means, this is the first time you have
413
+ received notice of violation of this License (for any work) from that
414
+ copyright holder, and you cure the violation prior to 30 days after
415
+ your receipt of the notice.
416
+
417
+ Termination of your rights under this section does not terminate the
418
+ licenses of parties who have received copies or rights from you under
419
+ this License. If your rights have been terminated and not permanently
420
+ reinstated, you do not qualify to receive new licenses for the same
421
+ material under section 10.
422
+
423
+ 9. Acceptance Not Required for Having Copies.
424
+
425
+ You are not required to accept this License in order to receive or
426
+ run a copy of the Program. Ancillary propagation of a covered work
427
+ occurring solely as a consequence of using peer-to-peer transmission
428
+ to receive a copy likewise does not require acceptance. However,
429
+ nothing other than this License grants you permission to propagate or
430
+ modify any covered work. These actions infringe copyright if you do
431
+ not accept this License. Therefore, by modifying or propagating a
432
+ covered work, you indicate your acceptance of this License to do so.
433
+
434
+ 10. Automatic Licensing of Downstream Recipients.
435
+
436
+ Each time you convey a covered work, the recipient automatically
437
+ receives a license from the original licensors, to run, modify and
438
+ propagate that work, subject to this License. You are not responsible
439
+ for enforcing compliance by third parties with this License.
440
+
441
+ An "entity transaction" is a transaction transferring control of an
442
+ organization, or substantially all assets of one, or subdividing an
443
+ organization, or merging organizations. If propagation of a covered
444
+ work results from an entity transaction, each party to that
445
+ transaction who receives a copy of the work also receives whatever
446
+ licenses to the work the party's predecessor in interest had or could
447
+ give under the previous paragraph, plus a right to possession of the
448
+ Corresponding Source of the work from the predecessor in interest, if
449
+ the predecessor has it or can get it with reasonable efforts.
450
+
451
+ You may not impose any further restrictions on the exercise of the
452
+ rights granted or affirmed under this License. For example, you may
453
+ not impose a license fee, royalty, or other charge for exercise of
454
+ rights granted under this License, and you may not initiate litigation
455
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
456
+ any patent claim is infringed by making, using, selling, offering for
457
+ sale, or importing the Program or any portion of it.
458
+
459
+ 11. Patents.
460
+
461
+ A "contributor" is a copyright holder who authorizes use under this
462
+ License of the Program or a work on which the Program is based. The
463
+ work thus licensed is called the contributor's "contributor version".
464
+
465
+ A contributor's "essential patent claims" are all patent claims
466
+ owned or controlled by the contributor, whether already acquired or
467
+ hereafter acquired, that would be infringed by some manner, permitted
468
+ by this License, of making, using, or selling its contributor version,
469
+ but do not include claims that would be infringed only as a
470
+ consequence of further modification of the contributor version. For
471
+ purposes of this definition, "control" includes the right to grant
472
+ patent sublicenses in a manner consistent with the requirements of
473
+ this License.
474
+
475
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
476
+ patent license under the contributor's essential patent claims, to
477
+ make, use, sell, offer for sale, import and otherwise run, modify and
478
+ propagate the contents of its contributor version.
479
+
480
+ In the following three paragraphs, a "patent license" is any express
481
+ agreement or commitment, however denominated, not to enforce a patent
482
+ (such as an express permission to practice a patent or covenant not to
483
+ sue for patent infringement). To "grant" such a patent license to a
484
+ party means to make such an agreement or commitment not to enforce a
485
+ patent against the party.
486
+
487
+ If you convey a covered work, knowingly relying on a patent license,
488
+ and the Corresponding Source of the work is not available for anyone
489
+ to copy, free of charge and under the terms of this License, through a
490
+ publicly available network server or other readily accessible means,
491
+ then you must either (1) cause the Corresponding Source to be so
492
+ available, or (2) arrange to deprive yourself of the benefit of the
493
+ patent license for this particular work, or (3) arrange, in a manner
494
+ consistent with the requirements of this License, to extend the patent
495
+ license to downstream recipients. "Knowingly relying" means you have
496
+ actual knowledge that, but for the patent license, your conveying the
497
+ covered work in a country, or your recipient's use of the covered work
498
+ in a country, would infringe one or more identifiable patents in that
499
+ country that you have reason to believe are valid.
500
+
501
+ If, pursuant to or in connection with a single transaction or
502
+ arrangement, you convey, or propagate by procuring conveyance of, a
503
+ covered work, and grant a patent license to some of the parties
504
+ receiving the covered work authorizing them to use, propagate, modify
505
+ or convey a specific copy of the covered work, then the patent license
506
+ you grant is automatically extended to all recipients of the covered
507
+ work and works based on it.
508
+
509
+ A patent license is "discriminatory" if it does not include within
510
+ the scope of its coverage, prohibits the exercise of, or is
511
+ conditioned on the non-exercise of one or more of the rights that are
512
+ specifically granted under this License. You may not convey a covered
513
+ work if you are a party to an arrangement with a third party that is
514
+ in the business of distributing software, under which you make payment
515
+ to the third party based on the extent of your activity of conveying
516
+ the work, and under which the third party grants, to any of the
517
+ parties who would receive the covered work from you, a discriminatory
518
+ patent license (a) in connection with copies of the covered work
519
+ conveyed by you (or copies made from those copies), or (b) primarily
520
+ for and in connection with specific products or compilations that
521
+ contain the covered work, unless you entered into that arrangement,
522
+ or that patent license was granted, prior to 28 March 2007.
523
+
524
+ Nothing in this License shall be construed as excluding or limiting
525
+ any implied license or other defenses to infringement that may
526
+ otherwise be available to you under applicable patent law.
527
+
528
+ 12. No Surrender of Others' Freedom.
529
+
530
+ If conditions are imposed on you (whether by court order, agreement or
531
+ otherwise) that contradict the conditions of this License, they do not
532
+ excuse you from the conditions of this License. If you cannot convey a
533
+ covered work so as to satisfy simultaneously your obligations under this
534
+ License and any other pertinent obligations, then as a consequence you may
535
+ not convey it at all. For example, if you agree to terms that obligate you
536
+ to collect a royalty for further conveying from those to whom you convey
537
+ the Program, the only way you could satisfy both those terms and this
538
+ License would be to refrain entirely from conveying the Program.
539
+
540
+ 13. Remote Network Interaction; Use with the GNU General Public License.
541
+
542
+ Notwithstanding any other provision of this License, if you modify the
543
+ Program, your modified version must prominently offer all users
544
+ interacting with it remotely through a computer network (if your version
545
+ supports such interaction) an opportunity to receive the Corresponding
546
+ Source of your version by providing access to the Corresponding Source
547
+ from a network server at no charge, through some standard or customary
548
+ means of facilitating copying of software. This Corresponding Source
549
+ shall include the Corresponding Source for any work covered by version 3
550
+ of the GNU General Public License that is incorporated pursuant to the
551
+ following paragraph.
552
+
553
+ Notwithstanding any other provision of this License, you have
554
+ permission to link or combine any covered work with a work licensed
555
+ under version 3 of the GNU General Public License into a single
556
+ combined work, and to convey the resulting work. The terms of this
557
+ License will continue to apply to the part which is the covered work,
558
+ but the work with which it is combined will remain governed by version
559
+ 3 of the GNU General Public License.
560
+
561
+ 14. Revised Versions of this License.
562
+
563
+ The Free Software Foundation may publish revised and/or new versions of
564
+ the GNU Affero General Public License from time to time. Such new versions
565
+ will be similar in spirit to the present version, but may differ in detail to
566
+ address new problems or concerns.
567
+
568
+ Each version is given a distinguishing version number. If the
569
+ Program specifies that a certain numbered version of the GNU Affero General
570
+ Public License "or any later version" applies to it, you have the
571
+ option of following the terms and conditions either of that numbered
572
+ version or of any later version published by the Free Software
573
+ Foundation. If the Program does not specify a version number of the
574
+ GNU Affero General Public License, you may choose any version ever published
575
+ by the Free Software Foundation.
576
+
577
+ If the Program specifies that a proxy can decide which future
578
+ versions of the GNU Affero General Public License can be used, that proxy's
579
+ public statement of acceptance of a version permanently authorizes you
580
+ to choose that version for the Program.
581
+
582
+ Later license versions may give you additional or different
583
+ permissions. However, no additional obligations are imposed on any
584
+ author or copyright holder as a result of your choosing to follow a
585
+ later version.
586
+
587
+ 15. Disclaimer of Warranty.
588
+
589
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
590
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
591
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
592
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
593
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
594
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
595
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
596
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
597
+
598
+ 16. Limitation of Liability.
599
+
600
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
602
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
603
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
604
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
605
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
606
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
607
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
608
+ SUCH DAMAGES.
609
+
610
+ 17. Interpretation of Sections 15 and 16.
611
+
612
+ If the disclaimer of warranty and limitation of liability provided
613
+ above cannot be given local legal effect according to their terms,
614
+ reviewing courts shall apply local law that most closely approximates
615
+ an absolute waiver of all civil liability in connection with the
616
+ Program, unless a warranty or assumption of liability accompanies a
617
+ copy of the Program in return for a fee.
618
+
619
+ END OF TERMS AND CONDITIONS
620
+
621
+ How to Apply These Terms to Your New Programs
622
+
623
+ If you develop a new program, and you want it to be of the greatest
624
+ possible use to the public, the best way to achieve this is to make it
625
+ free software which everyone can redistribute and change under these terms.
626
+
627
+ To do so, attach the following notices to the program. It is safest
628
+ to attach them to the start of each source file to most effectively
629
+ state the exclusion of warranty; and each file should have at least
630
+ the "copyright" line and a pointer to where the full notice is found.
631
+
632
+ <one line to give the program's name and a brief idea of what it does.>
633
+ Copyright (C) <year> <name of author>
634
+
635
+ This program is free software: you can redistribute it and/or modify
636
+ it under the terms of the GNU Affero General Public License as published
637
+ by the Free Software Foundation, either version 3 of the License, or
638
+ (at your option) any later version.
639
+
640
+ This program is distributed in the hope that it will be useful,
641
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
642
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643
+ GNU Affero General Public License for more details.
644
+
645
+ You should have received a copy of the GNU Affero General Public License
646
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
647
+
648
+ Also add information on how to contact you by electronic and paper mail.
649
+
650
+ If your software can interact with users remotely through a computer
651
+ network, you should also make sure that it provides a way for users to
652
+ get its source. For example, if your program is a web application, its
653
+ interface could display a "Source" link that leads users to an archive
654
+ of the code. There are many ways you could offer source, and different
655
+ solutions will be better for different programs; see section 13 for the
656
+ specific requirements.
657
+
658
+ You should also get your employer (if you work as a programmer) or school,
659
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
660
+ For more information on this, and how to apply and follow the GNU AGPL, see
661
+ <https://www.gnu.org/licenses/>.
README.md ADDED
@@ -0,0 +1,53 @@
+ ---
+ title: Autodoc Lifter
+ emoji: 🦊📝
+ colorFrom: yellow
+ colorTo: red
+ python_version: 3.11.9
+ sdk: streamlit
+ sdk_version: 1.37.1
+ suggested_hardware: t4-small
+ suggested_storage: small
+ app_file: app.py
+ header: mini
+ short_description: Good Local RAG for Bad PDFs
+ models: [timm/resnet18.a1_in1k, microsoft/table-transformer-detection, mixedbread-ai/mxbai-embed-large-v1, mixedbread-ai/mxbai-rerank-large-v1, meta-llama/Meta-Llama-3.1-8B-Instruct, Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5]
+ tags: [rag, llm, pdf, document]
+ license: agpl-3.0
+ pinned: true
+ preload_from_hub:
+ - timm/resnet18.a1_in1k
+ - microsoft/table-transformer-detection
+ - mixedbread-ai/mxbai-embed-large-v1
+ - mixedbread-ai/mxbai-rerank-large-v1
+ - Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5
+ ---
+
+ ## Autodoc Lifter
+
+ A document RAG system built with LLMs.
+ Some key goals for the project, once finished:
+
+ 0. All open, all local.
+ I don't want to be calling APIs. You can run the entire app locally and inspect the code and models.
+ This is particularly suitable for handling restricted information.
+ Yes, I know this is a web demo on Spaces, so don't actually do that here.
+ Use the GitHub link: (here, once it's no longer ClosedAI)
+
+ 1. Support for atrocious and varied PDFs.
+ Have images? Have tables? Have a set of PDFs with the worst quality and page layout known to man?
+ Give it a try here. I've been slowly building out custom processing for difficult documents by connecting Unstructured.IO to LlamaIndex in a slightly useful way.
+ (A future dream: get rid of Unstructured and build our own pipeline one day.)
+
+ 2. Multiple PDFs, handled with agents.
+ Instead of dumping all the documents into one central vector store and praying it works out,
+ I'm trying to be more thoughtful about how to incorporate multiple documents.
+
+ 3. Answers that are sourced and verifiable.
+ I'm sorry, but as a Definitely Human Person, I don't like hallucinated answers-ex-machina.
+ Responses should give actual citations \[0\] when pulling text directly from source documents,
+ and there should be a way to view the citations, the referenced text, and the document itself.
+
+ --- CITATIONS ---
+ \[0\] Relies primarily on fuzzy string matching, because it's computationally cheaper and also
+ ensures that cited text actually occurs in the source documents.
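To make the fuzzy-matching idea in \[0\] concrete, here is a minimal sketch using only Python's standard `difflib`; the function names and the 0.8 cutoff are illustrative assumptions, not the actual API of the repo's `citation.py`:

```python
# Illustrative sketch: check whether a cited span approximately occurs in a source text.
from difflib import SequenceMatcher

def best_fuzzy_score(cited: str, source: str, pad: int = 20) -> float:
    """Best similarity of `cited` against a window of `source` around their longest common block."""
    matcher = SequenceMatcher(None, source, cited, autojunk=False)
    block = matcher.find_longest_match(0, len(source), 0, len(cited))
    start = max(0, block.a - pad)
    end = min(len(source), block.a + len(cited) + pad)
    return SequenceMatcher(None, source[start:end], cited).ratio()

def citation_is_supported(cited: str, source: str, cutoff: float = 0.8) -> bool:
    # A citation passes only if the quoted text (nearly) appears verbatim in the source.
    return best_fuzzy_score(cited, source) >= cutoff
```

Unlike asking an LLM to verify its own quotes, a check of this kind is cheap and cannot be passed by text that never appears in the document.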
agent.py ADDED
@@ -0,0 +1,92 @@
+ #####################################################
+ ### DOCUMENT PROCESSOR [AGENT]
+ #####################################################
+ ### Jonathan Wang
+
+ # ABOUT:
+ # This creates an app to chat with PDFs.
+
+ # This is the AGENT
+ # which handles complex questions about the PDF.
+ #####################################################
+ ### TODO Board:
+ # https://docs.llamaindex.ai/en/stable/examples/agent/agent_runner/agent_runner_rag_controllable/#setup-human-in-the-loop-chat
+ # Investigate ObjectIndex and retrievers? https://docs.llamaindex.ai/en/stable/examples/agent/multi_document_agents/
+ # https://docs.llamaindex.ai/en/stable/module_guides/storing/chat_stores/
+
+ #####################################################
+ ### IMPORTS
+ import warnings
+ from typing import List
+
+ from streamlit import session_state as ss
+
+ from llama_index.core.settings import Settings
+ from llama_index.core.tools import QueryEngineTool, ToolMetadata
+ from llama_index.core.query_engine import SubQuestionQueryEngine
+
+ # Own Modules
+ from full_doc import FullDocument
+
+ #####################################################
+ ### CODE
+
+ ALLOWED_DOCUMENT_TOOLS = ['engine', 'subquestion_engine']
+ ALLOWED_TOOLS = ALLOWED_DOCUMENT_TOOLS
+
+ def _build_tool_from_fulldoc(fulldoc: FullDocument, tool_name: str) -> QueryEngineTool:
+     """Given a Full Document, build a QueryEngineTool from the specified engine.
+
+     Args:
+         fulldoc (FullDocument): The FullDocument (doc + query engines)
+         tool_name (str): The engine to use.
+
+     Returns:
+         QueryEngineTool: A query engine wrapper around the tool.
+     """
+     if (tool_name.lower() not in ALLOWED_DOCUMENT_TOOLS):
+         raise ValueError(f"`tool_name` must be one of {ALLOWED_DOCUMENT_TOOLS}")
+     if (getattr(fulldoc, tool_name, None) is None):
+         raise ValueError(f"`{tool_name}` must be created from the document first.")
+
+     # Build Tool
+     tool_description = ''
+     if tool_name == 'engine':
+         tool_description += 'A tool that answers simple questions about the following document:\n' + fulldoc.summary_oneline
+     elif tool_name == 'subquestion_engine':
+         tool_description += 'A tool that answers complex questions about the following document:\n' + fulldoc.summary_oneline
+
+     tool = QueryEngineTool(
+         query_engine=getattr(fulldoc, tool_name),
+         metadata=ToolMetadata(
+             name=tool_name,
+             description=tool_description
+         ),
+     )
+     return tool
+
+ def doclist_to_agent(doclist: List[FullDocument], fulldoc_tools_to_use: List[str] = ['engine']) -> SubQuestionQueryEngine:  # ReActAgent:
+     # Agent Tools
+     agent_tools = []
+
+     # Remove any tools that are not in the allowed list using set intersection.
+     tools_to_use = list(set(fulldoc_tools_to_use).intersection(set(ALLOWED_DOCUMENT_TOOLS)))
+     if (len(tools_to_use) < len(fulldoc_tools_to_use)):
+         removed_tools = set(fulldoc_tools_to_use) - set(ALLOWED_DOCUMENT_TOOLS)
+         warnings.warn(f"Tools {removed_tools} are not in the allowed list of tools. Skipping...")
+         del removed_tools
+
+     for tool in tools_to_use:
+         for doc in doclist:
+             agent_tools.append(_build_tool_from_fulldoc(doc, tool))
+
+     # Agent
+     # agent = ReActAgent.from_tools(
+     agent = SubQuestionQueryEngine.from_defaults(
+         # tools=agent_tools,
+         query_engine_tools=agent_tools,
+         llm=Settings.llm or ss.llm,
+         verbose=True,
+         # max_iterations=5
+     )
+
+     return agent
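A minimal usage sketch, assuming `ss.doclist` already holds `FullDocument` objects whose `engine` and `subquestion_engine` have been built (otherwise `_build_tool_from_fulldoc` raises); the question text is made up for illustration:

```python
# Hypothetical wiring: one sub-question engine over every uploaded document.
from agent import doclist_to_agent

agent = doclist_to_agent(ss.doclist, fulldoc_tools_to_use=["engine", "subquestion_engine"])
response = agent.query("Compare the conclusions of the uploaded reports.")  # example question
print(response.response)
```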
app.py ADDED
@@ -0,0 +1,471 @@
1
+ #####################################################
2
+ ### DOCUMENT PROCESSOR [APP]
3
+ #####################################################
4
+ ### Jonathan Wang
5
+
6
+ # ABOUT:
7
+ # This creates an app to chat with PDFs.
8
+
9
+ # This is the APP
10
+ # which runs the backend and codes the frontend UI.
11
+ #####################################################
12
+ ### TODO Board:
13
+ # Try ColPali? https://huggingface.co/vidore/colpali
14
+
15
+ #####################################################
16
+ ### PROGRAM IMPORTS
17
+ from __future__ import annotations
18
+
19
+ import base64
20
+ import gc
21
+ import logging
22
+ import os
23
+ import random
24
+ import sys
25
+ import warnings
26
+ from pathlib import Path
27
+ from typing import Any, cast
28
+
29
+ import nest_asyncio
30
+ import numpy as np
31
+ import streamlit as st
32
+ from llama_index.core import Settings, get_response_synthesizer
33
+ from llama_index.core.base.llms import BaseLLM
34
+ from llama_index.core.postprocessor import (
35
+ SentenceEmbeddingOptimizer,
36
+ SimilarityPostprocessor,
37
+ )
38
+ from llama_index.core.response_synthesizers import ResponseMode
39
+ from streamlit import session_state as ss
40
+ from summary import (
41
+ ImageSummaryMetadataAdder,
42
+ TableSummaryMetadataAdder,
43
+ get_tree_summarizer,
44
+ )
45
+ from torch.cuda import (
46
+ empty_cache,
47
+ get_device_name,
48
+ is_available,
49
+ manual_seed,
50
+ mem_get_info,
51
+ )
52
+ from transformers import set_seed
53
+
54
+ # Own Modules
55
+ from agent import doclist_to_agent
56
+ from citation import get_citation_builder
57
+ from full_doc import FullDocument
58
+ from keywords import KeywordMetadataAdder
59
+ from metadata_adder import UnstructuredPDFPostProcessor
60
+ from models import get_embedder, get_llm, get_multimodal_llm, get_reranker
61
+ from obs_logging import get_callback_manager, get_obs
62
+ from pdf_reader import UnstructuredPDFReader
63
+ from pdf_reader_utils import (
64
+ chunk_by_header,
65
+ clean_abbreviations,
66
+ combine_listitem_chunks,
67
+ dedupe_title_chunks,
68
+ remove_header_footer_repeated,
69
+ )
70
+ from parsers import get_parser
71
+ from prompts import get_qa_prompt, get_refine_prompt
72
+
73
+ #####################################
74
+ ### SETTINGS
75
+ # Logging
76
+ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
77
+ logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
78
+
79
+ # CUDA GPU memory avoid fragmentation.
80
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # avoid vram frag
81
+ os.environ["MAX_SPLIT_SIZE_MB"] = "128"
82
+ os.environ["SCARF_NO_ANALYTICS"] = "true" # get rid of data collection from Unstructured
83
+ os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
84
+
85
+ os.environ["HF_HOME"] = "/data/.huggingface" # save cached models on disk.
86
+
87
+ SEED = 31415926
88
+
89
+ print(f"CUDA Availablility: {is_available()}")
90
+ print(f"CUDA Device Name: {get_device_name()}")
91
+ print(f"CUDA Memory: {mem_get_info()}")
92
+
93
+ gc.collect()
94
+ empty_cache()
95
+
96
+ # Asyncio: fix some issues with nesting https://github.com/run-llama/llama_index/issues/9978
97
+ nest_asyncio.apply()
98
+
99
+ # Set seeds
100
+ if (random.getstate() is None):
101
+ random.seed(SEED) # python
102
+ np.random.seed(SEED) # numpy # TODO(Jonathan Wang): Replace with generator
103
+ manual_seed(SEED) # pytorch
104
+ set_seed(SEED) # transformers
105
+
106
+ # API Keys
107
+ os.environ["HF_TOKEN"] = st.secrets["huggingface_api_token"]
108
+ os.environ["OPENAI_API_KEY"] = st.secrets["openai_api_key"]
109
+ os.environ["GROQ_API_KEY"] = st.secrets["groq_api_key"]
110
+
111
+ #########################################################################
112
+ ### SESSION STATE INITIALIZATION
113
+ st.set_page_config(layout="wide")
114
+
115
+ if "pdf_ref" not in ss:
116
+ ss.input_pdf = []
117
+ if "doclist" not in ss:
118
+ ss.doclist = []
119
+ if "pdf_reader" not in ss:
120
+ ss.pdf_reader = None
121
+ if "pdf_postprocessor" not in ss:
122
+ ss.pdf_postprocessor = None
123
+ # if 'sentence_model' not in ss:
124
+ # ss.sentence_model = None # sentence splitting model, as alternative to nltk/PySBD
125
+ if "embed_model" not in ss:
126
+ ss.embed_model = None
127
+ gc.collect()
128
+ empty_cache()
129
+ if "reranker_model" not in ss:
130
+ ss.reranker_model = None
131
+ gc.collect()
132
+ empty_cache()
133
+ if "llm" not in ss:
134
+ ss.llm = None
135
+ gc.collect()
136
+ empty_cache()
137
+ if "multimodal_llm" not in ss:
138
+ ss.multimodal_llm = None
139
+ gc.collect()
140
+ empty_cache()
141
+ if "callback_manager" not in ss:
142
+ ss.callback_manager = None
143
+ if "node_parser" not in ss:
144
+ ss.node_parser = None
145
+ if "node_postprocessors" not in ss:
146
+ ss.node_postprocessors = None
147
+ if "response_synthesizer" not in ss:
148
+ ss.response_synthesizer = None
149
+ if "tree_summarizer" not in ss:
150
+ ss.tree_summarizer = None
151
+ if "citation_builder" not in ss:
152
+ ss.citation_builder = None
153
+ if "agent" not in ss:
154
+ ss.agent = None
155
+ if "observability" not in ss:
156
+ ss.observability = None
157
+
158
+ if "uploaded_files" not in ss:
159
+ ss.uploaded_files = []
160
+ if "selected_file" not in ss:
161
+ ss.selected_file = None
162
+
163
+ if "chat_messages" not in ss:
164
+ ss.chat_messages = []
165
+
166
+ ################################################################################
167
+ ### SCRIPT
168
+
169
+ st.markdown("""
170
+ <style>
171
+ .block-container {
172
+ padding-top: 3rem;
173
+ padding-bottom: 0rem;
174
+ padding-left: 3rem;
175
+ padding-right: 3rem;
176
+ }
177
+ </style>
178
+ """, unsafe_allow_html=True)
179
+
180
+ ### # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
181
+ ### UI
182
+ st.text("Autodoc Lifter Local PDF Chatbot (Built with Meta🦙3)")
183
+ col_left, col_right = st.columns([1, 1])
184
+
185
+ ### # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
186
+ ### PDF Upload UI (Left Panel)
187
+ with st.sidebar:
188
+ uploaded_files = st.file_uploader(
189
+ label="Upload a PDF file.",
190
+ type="pdf",
191
+ accept_multiple_files=True,
192
+ label_visibility="collapsed",
193
+ )
194
+
195
+ ### # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
196
+ ### PDF Display UI (Middle Panel)
197
+ # NOTE: This currently only displays the PDF, which requires user interaction (below)
198
+
199
+ ### # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
200
+ ### Chat UI (Right Panel)
201
+
202
+ with col_right:
203
+ messages_container = st.container(height=475, border=False)
204
+ input_container = st.container(height=80, border=False)
205
+
206
+ with messages_container:
207
+ for message in ss.chat_messages:
208
+ with st.chat_message(message["role"]):
209
+ st.markdown(message["content"])
210
+
211
+ with input_container:
212
+ # Accept user input
213
+ prompt = st.chat_input("Ask your question about the document here.")
214
+
215
+ ### # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
216
+ ### Get Models and Settings
217
+ # Get Vision LLM
218
+ if (ss.multimodal_llm is None):
219
+ print(f"CUDA Memory Pre-VLLM: {mem_get_info()}")
220
+ vision_llm = get_multimodal_llm()
221
+ ss.multimodal_llm = vision_llm
222
+
223
+ # Get LLM
224
+ if (ss.llm is None):
225
+ print(f"CUDA Memory Pre-LLM: {mem_get_info()}")
226
+ llm = get_llm()
227
+ ss.llm = llm
228
+ Settings.llm = cast(BaseLLM, llm)
229
+
230
+ # Get Sentence Splitting Model.
231
+ # if (ss.sentence_model is None):
232
+ # sent_splitter = get_sat_sentence_splitter('sat-3l-sm')
233
+ # ss.sentence_model = sent_splitter
234
+
235
+ # Get Embedding Model
236
+ if (ss.embed_model is None):
237
+ print(f"CUDA Memory Pre-Embedding: {mem_get_info()}")
238
+ embed_model = get_embedder()
239
+ ss.embed_model = embed_model
240
+ Settings.embed_model = embed_model
241
+
242
+ # Get Reranker
243
+ if (ss.reranker_model is None):
244
+ print(f"CUDA Memory Pre-Reranking: {mem_get_info()}")
245
+ ss.reranker_model = get_reranker()
246
+
247
+ # Get Callback Manager
248
+ if (ss.callback_manager is None):
249
+ callback_manager = get_callback_manager()
250
+ ss.callback_manager = callback_manager
251
+ Settings.callback_manager = callback_manager
252
+
253
+ # Get Node Parser
254
+ if (ss.node_parser is None):
255
+ node_parser = get_parser(
256
+ embed_model=Settings.embed_model,
257
+ callback_manager=ss.callback_manager
258
+ )
259
+ ss.node_parser = node_parser
260
+ Settings.node_parser = node_parser
261
+
262
+ #### Get Observability
263
+ if (ss.observability is None):
264
+ ss.observability = get_obs()
265
+
266
+ ### Get PDF Reader
267
+ if (ss.pdf_reader is None):
268
+ ss.pdf_reader = UnstructuredPDFReader()
269
+
270
+ ### Get PDF Reader Postprocessing
271
+ if (ss.pdf_postprocessor is None):
272
+ # Get embedding
273
+ # regex_adder = RegexMetadataAdder(regex_pattern=) # Are there any that I need?
274
+ keyword_adder = KeywordMetadataAdder(metadata_name="keywords")
275
+ table_summary_adder = TableSummaryMetadataAdder(llm=ss.llm)
276
+ image_summary_adder = ImageSummaryMetadataAdder(llm=ss.multimodal_llm)
277
+
278
+ pdf_postprocessor = UnstructuredPDFPostProcessor(
279
+ embed_model=ss.embed_model,
280
+ metadata_adders=[keyword_adder, table_summary_adder, image_summary_adder]
281
+ )
282
+ ss.pdf_postprocessor = pdf_postprocessor
283
+
284
+ #### Get Observability
285
+ if (ss.observability is None):
286
+ ss.observability = get_obs()
287
+ observability = ss.observability
288
+
289
+ ### Get Node Postprocessor Pipeline
290
+ if (ss.node_postprocessors is None):
291
+ from nltk.tokenize import PunktTokenizer
292
+ punkt_tokenizer = PunktTokenizer()
293
+ ss.node_postprocessors = [
294
+ SimilarityPostprocessor(similarity_cutoff=0.01), # remove nodes unrelated to query
295
+ ss.reranker_model, # rerank
296
+ # remove sentences less related to query. lower is stricter
297
+ SentenceEmbeddingOptimizer(tokenizer_fn=punkt_tokenizer.tokenize, percentile_cutoff=0.2),
298
+ ]
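+ # NOTE: at query time these run in list order (see RAGQueryEngine._apply_node_postprocessors
+ # in engine.py): similarity filter first, then the reranker, then sentence-level pruning.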
299
+
300
+ ### Get Response Synthesizer
301
+ if (ss.response_synthesizer is None):
302
+ ss.response_synthesizer = get_response_synthesizer(
303
+ response_mode=ResponseMode.COMPACT,
304
+ text_qa_template=get_qa_prompt(),
305
+ refine_template=get_refine_prompt()
306
+ )
307
+
308
+ ### Get Tree Summarizer
309
+ if (ss.tree_summarizer is None):
310
+ ss.tree_summarizer = get_tree_summarizer()
311
+
312
+ ### Get Citation Builder
313
+ if (ss.citation_builder is None):
314
+ ss.citation_builder = get_citation_builder()
315
+
316
+ ### # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
317
+ ### Handle User Interaction
318
+ def handle_new_pdf(file_io: Any) -> None:
319
+ """Handle processing a new source PDF file document."""
320
+ with st.sidebar:
321
+ with (st.spinner("Reading input file, this may take some time...")):
322
+ ### Save Locally
323
+ # TODO(Jonathan Wang): Get the user to upload their file with a reference name in a separate tab.
324
+ if not Path(__file__).parent.joinpath("data").exists():
325
+ print("NEWPDF: Making data directory...")
326
+ Path(__file__).parent.joinpath("data").mkdir(parents=True)
327
+ with open(Path(__file__).parent.joinpath("data/input.pdf"), "wb") as f:
328
+ print("NEWPDF: Writing input file...")
329
+ f.write(file_io.getbuffer())
330
+
331
+ ### Create Document
332
+ print("NEWPDF: Building Document...")
333
+ new_document = FullDocument(
334
+ name="input.pdf",
335
+ file_path=Path(__file__).parent.joinpath("data/input.pdf"),
336
+ )
337
+
338
+ #### Process document.
339
+ print("NEWPDF: Processing document into nodes...")
340
+ new_document.file_to_nodes(
341
+ reader=ss.pdf_reader,
342
+ postreaders=[
343
+ clean_abbreviations, dedupe_title_chunks, combine_listitem_chunks,
344
+ remove_header_footer_repeated, chunk_by_header
345
+ ],
346
+ node_parser=ss.node_parser,
347
+ postparsers=[ss.pdf_postprocessor],
348
+ )
349
+
350
+ ### Get Storage Context
351
+ with (st.spinner("Processing input file, this may take some time...")):
352
+ new_document.nodes_to_summary(summarizer=ss.tree_summarizer)
353
+ new_document.summary_to_oneline(summarizer=ss.tree_summarizer)
354
+ new_document.nodes_to_document_keywords()
355
+ new_document.nodes_to_storage()
356
+ ### Get Retrieval on Vector Store Index
357
+ with (st.spinner("Building retriever for the input file...")):
358
+ new_document.storage_to_retriever(callback_manager=ss.callback_manager)
359
+ ### Get LLM Query Engine
360
+ with (st.spinner("Building query responder for the input file...")):
361
+ new_document.retriever_to_engine(
362
+ response_synthesizer=ss.response_synthesizer,
363
+ callback_manager=ss.callback_manager
364
+ )
365
+ new_document.engine_to_sub_question_engine()
366
+
367
+ ### Officially Add to Document List
368
+ ss.uploaded_files.append(file_io) # Left UI Bar
369
+ ss.doclist.append(new_document) # Document list for RAG. # TODO(Jonathan Wang): Fix potential duplication.
370
+
371
+ ### Get LLM Agent
372
+ with (st.spinner("Building LLM Agent for the input file...")):
373
+ agent = doclist_to_agent(ss.doclist)
374
+ ss.agent = agent
375
+
376
+ # All done!
377
+ st.toast("All done!")
378
+
379
+ # Display summary of new document in chat.
380
+ with messages_container:
381
+ ss.chat_messages.append(
382
+ {"role": "assistant", "content": new_document.summary_oneline}
383
+ )
384
+ with st.chat_message("assistant"):
385
+ st.markdown(new_document.summary_oneline)
386
+
387
+ ### Cleaning
388
+ empty_cache()
389
+ gc.collect()
390
+
391
+
392
+ def handle_chat_message(user_message: str) -> str:
393
+ # Get Response
394
+ if (not hasattr(ss, "doclist") or len(ss.doclist) == 0):
395
+ return "Please upload a document to get started."
396
+
397
+ if (not hasattr(ss, "agent") or ss.agent is None):
398
+ warnings.warn("No LLM Agent found. Attempting to create one.", stacklevel=2)
399
+ with st.sidebar, (st.spinner("Building LLM Agent for the input file...")):
400
+ agent = doclist_to_agent(ss.doclist)
401
+ ss.agent = agent
402
+
403
+ response = ss.agent.query(user_message)
404
+ # Get citations if available
405
+ response = ss.citation_builder.get_citations(response, citation_threshold=60)
406
+ # Add citations to response text
407
+ response_with_citations = ss.citation_builder.add_citations_to_response(response)
408
+ return str(response_with_citations.response)
409
+
410
+ @st.cache_data
411
+ def get_pdf_display(
412
+ file: Any,
413
+ app_width: str = "100%",
414
+ app_height: str = "500",
415
+ starting_page_number: int | None = None
416
+ ) -> str:
417
+ # Read file as binary
418
+ file_bytes = file.getbuffer()
419
+ base64_pdf = base64.b64encode(file_bytes).decode("utf-8")
420
+
421
+ pdf_display = f'<embed src="data:application/pdf;base64,{base64_pdf}"' # TODO(Jonathan Wang): iframe vs embed
422
+ if starting_page_number is not None:
423
+ pdf_display += f"#page={starting_page_number}"
424
+ pdf_display += f' width="{app_width}" height="{app_height}" type="application/pdf"></embed>' # iframe vs embed
425
+ return (pdf_display)
426
+
427
+ # Upload
428
+ with st.sidebar:
429
+ uploaded_files = uploaded_files or [] # handle case when no file is uploaded
430
+ for uploaded_file in uploaded_files:
431
+ if (uploaded_file not in ss.uploaded_files):
432
+ handle_new_pdf(uploaded_file)
433
+
434
+ if (ss.selected_file is None and ss.uploaded_files):
435
+ ss.selected_file = ss.uploaded_files[-1]
436
+
437
+ file_names = [file.name for file in ss.uploaded_files]
438
+ selected_file_name = st.radio("Uploaded Files:", file_names)
439
+ if selected_file_name:
440
+ ss.selected_file = [file for file in ss.uploaded_files if file.name == selected_file_name][-1]
441
+
442
+ with col_left:
443
+ if (ss.selected_file is None):
444
+ selected_file_name = "Upload a file."
445
+ st.markdown(f"## {selected_file_name}")
446
+
447
+ elif (ss.selected_file is not None):
448
+ selected_file = ss.selected_file
449
+ selected_file_name = selected_file.name
450
+
451
+ if (selected_file.type == "application/pdf"):
452
+ pdf_display = get_pdf_display(selected_file, app_width="100%", app_height="550")
453
+ st.markdown(pdf_display, unsafe_allow_html=True)
454
+
455
+ # Chat
456
+ if prompt:
457
+ with messages_container:
458
+ with st.chat_message("user"):
459
+ st.markdown(prompt)
460
+ ss.chat_messages.append({"role": "user", "content": prompt})
461
+
462
+ with st.spinner("Generating response..."):
463
+ # Get Response
464
+ response = handle_chat_message(prompt)
465
+
466
+ if response:
467
+ ss.chat_messages.append(
468
+ {"role": "assistant", "content": response}
469
+ )
470
+ with st.chat_message("assistant"):
471
+ st.markdown(response)
citation.py ADDED
@@ -0,0 +1,245 @@
1
+ #####################################################
2
+ ### DOCUMENT PROCESSOR [CITATION]
3
+ #####################################################
4
+ # Jonathan Wang
5
+
6
+ # ABOUT:
7
+ # This project creates an app to chat with PDFs.
8
+
9
+ # This is the CITATION
10
+ # which adds citation information to the LLM response
11
+ #####################################################
12
+ ## TODO Board:
13
+ # Investigate using LLM model weights with attention to determine citations.
14
+
15
+ # https://gradientscience.org/contextcite/
16
+ # https://github.com/MadryLab/context-cite/blob/main/context_cite/context_citer.py#L25
17
+ # https://github.com/MadryLab/context-cite/blob/main/context_cite/context_partitioner.py
18
+ # https://github.com/MadryLab/context-cite/blob/main/context_cite/solver.py
19
+
20
+ #####################################################
21
+ ## IMPORTS
22
+ from __future__ import annotations
23
+
24
+ from collections import defaultdict
25
+ from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
26
+ import warnings
27
+
28
+ import numpy as np
29
+ from llama_index.core.base.response.schema import RESPONSE_TYPE, Response
30
+
31
+ if TYPE_CHECKING:
32
+ from llama_index.core.schema import NodeWithScore
33
+
34
+ # Own Modules
35
+ from merger import _merge_on_scores
36
+ from rapidfuzz import fuzz, process, utils
37
+
38
+
39
+ # Lazy Loading:
40
+ # from nltk import sent_tokenize # noqa: ERA001
41
+
42
+ #####################################################
43
+ ## CODE
44
+
45
+ class CitationBuilder:
46
+ """Class that builds citations from responses."""
47
+
48
+ text_splitter: Callable[[str], list[str]]
49
+
50
+ def __init__(self, text_splitter: Callable[[str], list[str]] | None = None) -> None:
51
+ if not text_splitter:
52
+ from nltk import sent_tokenize
53
+ text_splitter = sent_tokenize
54
+ self.text_splitter = text_splitter
55
+
56
+ @classmethod
57
+ def class_name(cls) -> str:
58
+ return "CitationBuilder"
59
+
60
+ def convert_to_response(self, input_response: RESPONSE_TYPE) -> Response:
61
+ # Convert all other response types into the baseline response
62
+ # Otherwise, we won't have the full response text generated.
63
+ if not isinstance(input_response, Response):
64
+ response = input_response.get_response()
65
+ if isinstance(response, Response):
66
+ return response
67
+ else:
68
+ # TODO(Jonathan Wang): Handle async responses with Coroutines
69
+ msg = "Expected Response object, got Coroutine"
70
+ raise TypeError(msg)
71
+ else:
72
+ return input_response
73
+
74
+ def find_nearest_whitespace(
75
+ self,
76
+ input_text: str,
77
+ input_index: int,
78
+ right_to_left: bool=False
79
+ ) -> int:
80
+ """Given a string and an index, find the index of the whitespace character nearest to that index."""
81
+ if (input_index < 0 or input_index >= len(input_text)):
82
+ msg = "find_nearest_whitespace: index beyond string."
83
+ raise ValueError(msg)
84
+
85
+ find_text = ""
86
+ if (right_to_left):
87
+ find_text = input_text[:input_index]
88
+ for index, char in enumerate(reversed(find_text)):
89
+ if (char.isspace()):
90
+ return (len(find_text)-1 - index)
91
+ return (0)
92
+ else:
93
+ find_text = input_text[input_index:]
94
+ for index, char in enumerate(find_text):
95
+ if (char.isspace()):
96
+ return (input_index + index)
97
+ return (len(input_text))
98
+
99
+ def get_citations(
100
+ self,
101
+ input_response: RESPONSE_TYPE,
102
+ citation_threshold: int = 70,
103
+ citation_len: int = 128
104
+ ) -> Response:
105
+ response = self.convert_to_response(input_response)
106
+
107
+ if not response.response or not response.source_nodes:
108
+ return response
109
+
110
+ # Get current response text:
111
+ response_text = response.response
112
+ source_nodes = response.source_nodes
113
+
114
+ # 0. Get candidate nodes for citation.
115
+ # Fuzzy match each source node text against the response text.
116
+ source_texts: dict[str, list[NodeWithScore]] = defaultdict(list)
117
+ for node in source_nodes:
118
+ if (
119
+ (len(getattr(node.node, "text", "")) > 0) and
120
+ (len(node.node.metadata) > 0)
121
+ ): # filter out non-text nodes and intermediate nodes from SubQueryQuestionEngine
122
+ source_texts[node.node.text].append(node) # type: ignore
123
+
124
+ fuzzy_matches = process.extract(
125
+ response_text,
126
+ list(source_texts.keys()),
127
+ scorer=fuzz.partial_ratio,
128
+ processor=utils.default_process,
129
+ score_cutoff=max(10, citation_threshold - 10)
130
+ )
131
+
132
+ # Convert extracted matches of form (Match, Score, Rank) into scores for all source_texts.
133
+ if fuzzy_matches:
134
+ fuzzy_texts, _, _ = zip(*fuzzy_matches)
135
+ fuzzy_nodes = [source_texts[text][0] for text in fuzzy_texts]
136
+ else:
137
+ return response
138
+
139
+ # 1. Combine fuzzy score and source text semantic/reranker score.
140
+ # NOTE: for our merge here, we value the nodes with strong fuzzy text matching over other node types.
141
+ cited_nodes = _merge_on_scores(
142
+ a_list=fuzzy_nodes,
143
+ b_list=source_nodes, # same nodes, different scores (fuzzy vs semantic/bm25/reranker)
144
+ a_scores_input=[getattr(node, "score", np.nan) for node in fuzzy_nodes],
145
+ b_scores_input=[getattr(node, "score", np.nan) for node in source_nodes],
146
+ a_weight=0.85, # we want to heavily prioritize the fuzzy text for matches
147
+ top_k=3 # maximum of three source options.
148
+ )
149
+
150
+ # 2. Add cited nodes text to the response text, and cited nodes as metadata.
151
+ # For each sentence in the response, if there is a match in the source text, add a citation tag.
152
+ response_sentences = self.text_splitter(response_text)
153
+ output_text = ""
154
+ output_citations = ""
155
+ citation_tag = 0
156
+
157
+ for response_sentence in response_sentences:
158
+ # Get fuzzy citation at sentence level
159
+ best_alignment = None
160
+ best_score = 0
161
+ best_node = None
162
+
163
+ for _, source_node in enumerate(source_nodes):
164
+ source_node_text = getattr(source_node.node, "text", "")
165
+ new_alignment = fuzz.partial_ratio_alignment(
166
+ response_sentence,
167
+ source_node_text,
168
+ processor=utils.default_process, score_cutoff=citation_threshold
169
+ )
170
+ new_score = 0.0
171
+
172
+ if (new_alignment is not None and (new_alignment.src_end - new_alignment.src_start) > 0):
173
+ new_score = fuzz.ratio(
174
+ source_node_text[new_alignment.src_start:new_alignment.src_end],
175
+ response_sentence[new_alignment.dest_start:new_alignment.dest_end],
176
+ processor=utils.default_process
177
+ )
178
+ new_score = new_score * (new_alignment.src_end - new_alignment.src_start) / float(len(response_sentence))
179
+
180
+ if (new_score > best_score):
181
+ best_alignment = new_alignment
182
+ best_score = new_score
183
+ best_node = source_node
184
+
185
+ if (best_score <= 0 or best_node is None or best_alignment is None):
186
+ # No match
187
+ output_text += response_sentence
188
+ continue
189
+
190
+ # Add citation tag to text
191
+ citation_tag_position = self.find_nearest_whitespace(response_sentence, best_alignment.dest_start, right_to_left=True)
192
+ output_text += response_sentence[:citation_tag_position] # response up to the quote
193
+ output_text += f" [{citation_tag}] " # add citation tag
194
+ output_text += response_sentence[citation_tag_position:] # response after the quote
195
+
196
+ # Add citation text to citations
197
+ citation = getattr(best_node.node, "text", "")
198
+ citation_margin = round((citation_len - (best_alignment.src_end - best_alignment.src_start)) / 2)
199
+ nearest_whitespace_pre = self.find_nearest_whitespace(citation, max(0, best_alignment.src_start), right_to_left=True)
200
+ nearest_whitespace_post = self.find_nearest_whitespace(citation, min(len(citation)-1, best_alignment.src_end), right_to_left=False)
201
+ nearest_whitespace_prewindow = self.find_nearest_whitespace(citation, max(0, nearest_whitespace_pre - citation_margin), right_to_left=True)
202
+ nearest_whitespace_postwindow = self.find_nearest_whitespace(citation, min(len(citation)-1, nearest_whitespace_post + citation_margin), right_to_left=False)
203
+
204
+ citation_text = (
205
+ citation[nearest_whitespace_prewindow+1: nearest_whitespace_pre+1]
206
+ + "|||||"
207
+ + citation[nearest_whitespace_pre+1:nearest_whitespace_post]
208
+ + "|||||"
209
+ + citation[nearest_whitespace_post:nearest_whitespace_postwindow]
210
+ + f"… <<{best_node.node.metadata.get('name', '')}, Page(s) {best_node.node.metadata.get('page_number', '')}>>"
211
+ )
212
+ output_citations += f"[{citation_tag}]: {citation_text}\n\n"
213
+ citation_tag += 1
214
+
215
+ # Create output
216
+ if response.metadata is not None:
217
+ # NOTE: metadata certainly exists by now, but the schema allows None...
218
+ response.metadata["cited_nodes"] = cited_nodes
219
+ response.metadata["citations"] = output_citations
220
+ response.response = output_text # update response to include citation tags
221
+ return response
222
+
223
+ def add_citations_to_response(self, input_response: Response) -> Response:
224
+ if not hasattr(input_response, "metadata"):
225
+ msg = "Input response does not have metadata."
226
+ raise ValueError(msg)
227
+ elif input_response.metadata is None or "citations" not in input_response.metadata:
228
+ warnings.warn("Input response does not have citations.", stacklevel=2)
229
+ input_response = self.get_citations(input_response)
230
+
231
+ # Add citation text to response
232
+ if (hasattr(input_response, "metadata") and input_response.metadata.get("citations", "") != ""):
233
+ input_response.response = (
234
+ input_response.response
235
+ + "\n\n----- CITATIONS -----\n\n"
236
+ + input_response.metadata.get('citations', "")
237
+ ) # type: ignore
238
+ return input_response
239
+
240
+ def __call__(self, input_response: RESPONSE_TYPE, *args: Any, **kwds: Any) -> Response:
241
+ return self.get_citations(input_response, *args, **kwds)
242
+
243
+
244
+ def get_citation_builder() -> CitationBuilder:
245
+ return CitationBuilder()
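+
+
+ # Illustrative usage sketch (names like `engine` are assumed placeholders; the real wiring
+ # lives in app.py's handle_chat_message):
+ # builder = get_citation_builder()
+ # response = engine.query("What does the document conclude?")
+ # response = builder.get_citations(response, citation_threshold=60)
+ # print(builder.add_citations_to_response(response).response)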
engine.py ADDED
@@ -0,0 +1,126 @@
1
+ #####################################################
2
+ ### DOCUMENT PROCESSOR [ENGINE]
3
+ #####################################################
4
+ # Jonathan Wang
5
+
6
+ # ABOUT:
7
+ # This project creates an app to chat with PDFs.
8
+
9
+ # This is the ENGINE
10
+ # which defines how LLMs handle processing.
11
+ #####################################################
12
+ ## TODO Board:
13
+
14
+ #####################################################
15
+ ## IMPORTS
16
+ from __future__ import annotations
17
+
18
+ import gc
19
+ from typing import TYPE_CHECKING, Callable, List, Optional, cast
20
+
21
+ from llama_index.core.query_engine import CustomQueryEngine
22
+ from llama_index.core.schema import NodeWithScore, QueryBundle
23
+ from llama_index.core.settings import (
24
+ Settings,
25
+ )
26
+ from torch.cuda import empty_cache
27
+
28
+ if TYPE_CHECKING:
29
+ from llama_index.core.base.response.schema import Response
30
+ from llama_index.core.callbacks import CallbackManager
31
+ from llama_index.core.postprocessor.types import BaseNodePostprocessor
32
+ from llama_index.core.response_synthesizers import (
33
+ BaseSynthesizer,
34
+ )
35
+ from llama_index.core.retrievers import BaseRetriever
36
+
37
+ # Own Modules
38
+
39
+ #####################################################
40
+ ## CODE
41
+ class RAGQueryEngine(CustomQueryEngine):
42
+ """Custom RAG Query Engine."""
43
+
44
+ retriever: BaseRetriever
45
+ response_synthesizer: BaseSynthesizer
46
+ node_postprocessors: Optional[List[BaseNodePostprocessor]] = []
47
+
48
+ # def __init__(
49
+ # self,
50
+ # retriever: BaseRetriever,
51
+ # response_synthesizer: Optional[BaseSynthesizer] = None,
52
+ # node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
53
+ # callback_manager: Optional[CallbackManager] = None,
54
+ # ) -> None:
55
+ # self._retriever = retriever
56
+ # # callback_manager = (
57
+ # # callback_manager
58
+ # # Settings.callback_manager
59
+ # # )
60
+ # # llm = llm or Settings.llm
61
+
62
+ # self._response_synthesizer = response_synthesizer or get_response_synthesizer(
63
+ # # llm=llm,
64
+ # # service_context=service_context,
65
+ # # callback_manager=callback_manager,
66
+ # )
67
+ # self._node_postprocessors = node_postprocessors or []
68
+ # self._metadata_mode = metadata_mode
69
+
70
+ # for node_postprocessor in self._node_postprocessors:
71
+ # node_postprocessor.callback_manager = callback_manager
72
+
73
+ # super().__init__(callback_manager=callback_manager)
74
+
75
+ @classmethod
76
+ def class_name(cls) -> str:
77
+ """Class name."""
78
+ return "RAGQueryEngine"
79
+
80
+ # taken from Llamaindex CustomEngine:
81
+ # https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/query_engine/retriever_query_engine.py#L134
82
+ def _apply_node_postprocessors(
83
+ self, nodes: list[NodeWithScore], query_bundle: QueryBundle
84
+ ) -> list[NodeWithScore]:
85
+ if self.node_postprocessors is None:
86
+ return nodes
87
+
88
+ for node_postprocessor in self.node_postprocessors:
89
+ nodes = node_postprocessor.postprocess_nodes(
90
+ nodes, query_bundle=query_bundle
91
+ )
92
+ return nodes
93
+
94
+ def retrieve(self, query_bundle: QueryBundle) -> list[NodeWithScore]:
95
+ nodes = self.retriever.retrieve(query_bundle)
96
+ return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)
97
+
98
+ async def aretrieve(self, query_bundle: QueryBundle) -> list[NodeWithScore]:
99
+ nodes = await self.retriever.aretrieve(query_bundle)
100
+ return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)
101
+
102
+ def custom_query(self, query_str: str) -> Response:
103
+ # Convert query string into query bundle
104
+ query_bundle = QueryBundle(query_str=query_str)
105
+ nodes = self.retrieve(query_bundle) # also does the postprocessing.
106
+
107
+ response_obj = self.response_synthesizer.synthesize(query_bundle, nodes)
108
+
109
+ empty_cache()
110
+ gc.collect()
111
+ return cast(Response, response_obj) # type: ignore
112
+
113
+
114
+ # @st.cache_resource # none of these can be hashable or cached :(
115
+ def get_engine(
116
+ retriever: BaseRetriever,
117
+ response_synthesizer: BaseSynthesizer,
118
+ node_postprocessors: list[BaseNodePostprocessor] | None = None,
119
+ callback_manager: CallbackManager | None = None,
120
+ ) -> RAGQueryEngine:
121
+ return RAGQueryEngine(
122
+ retriever=retriever,
123
+ response_synthesizer=response_synthesizer,
124
+ node_postprocessors=node_postprocessors,
125
+ callback_manager=callback_manager or Settings.callback_manager,
126
+ )
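+
+
+ # Illustrative usage sketch (`my_retriever` and `my_synthesizer` are assumed placeholders;
+ # see full_doc.py's retriever_to_engine for the real call site):
+ # engine = get_engine(retriever=my_retriever, response_synthesizer=my_synthesizer)
+ # answer = engine.query("Summarize the key findings.")  # retrieve -> postprocess -> synthesize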
full_doc.py ADDED
@@ -0,0 +1,336 @@
1
+ #####################################################
2
+ ### DOCUMENT PROCESSOR [FULLDOC]
3
+ #####################################################
4
+ ### Jonathan Wang
5
+
6
+ # ABOUT:
7
+ # This creates an app to chat with PDFs.
8
+
9
+ # This is the FULLDOC
10
+ # which is a class that associates documents
11
+ # with their critical information
12
+ # and their tools. (keywords, summary, queryengine, etc.)
13
+ #####################################################
14
+ ### TODO Board:
15
+ # Automatically determine which reader to use for each document based on the file type.
16
+
17
+ #####################################################
18
+ ### PROGRAM SETTINGS
19
+
20
+ #####################################################
21
+ ### PROGRAM IMPORTS
22
+ from __future__ import annotations
23
+
24
+ import asyncio
25
+ from pathlib import Path
26
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar
27
+ from uuid import UUID, uuid4
28
+
29
+ from llama_index.core import StorageContext, VectorStoreIndex
30
+ from llama_index.core.query_engine import SubQuestionQueryEngine
31
+ from llama_index.core.schema import BaseNode, TransformComponent
32
+ from llama_index.core.settings import Settings
33
+ from llama_index.core.tools import QueryEngineTool, ToolMetadata
34
+ from streamlit import session_state as ss
35
+
36
+ if TYPE_CHECKING:
37
+ from llama_index.core.base.base_query_engine import BaseQueryEngine
38
+ from llama_index.core.callbacks import CallbackManager
39
+ from llama_index.core.node_parser import NodeParser
40
+ from llama_index.core.readers.base import BaseReader
41
+ from llama_index.core.response_synthesizers import BaseSynthesizer
42
+ from llama_index.core.retrievers import BaseRetriever
43
+
44
+ # Own Modules
45
+ from engine import get_engine
46
+ from keywords import KeywordMetadataAdder
47
+ from retriever import get_retriever
48
+ from storage import get_docstore, get_vector_store
49
+ from summary import DEFAULT_ONELINE_SUMMARY_TEMPLATE, DEFAULT_TREE_SUMMARY_TEMPLATE
50
+
51
+ #####################################################
52
+ ### SCRIPT
53
+
54
+ GenericNode = TypeVar("GenericNode", bound=BaseNode)
55
+
56
+ class FullDocument:
57
+ """Bundles all the information about a document together.
58
+
59
+ Args:
60
+ name (str): The name of the document.
61
+ file_path (Path): The path to the document.
62
+ summary (str): The summary of the document.
63
+ keywords (List[str]): The keywords of the document.
64
+ entities (List[str]): The entities of the document.
65
+ vector_store (BaseDocumentStore): The vector store of the document.
66
+ """
67
+
68
+ # Identifiers
69
+ id: UUID
70
+ name: str
71
+ file_path: Path
72
+ file_name: str
73
+
74
+ # Basic Contents
75
+ summary: str
76
+ summary_oneline: str # A one line summary of the document.
77
+ keywords: set[str] # List of keywords in document.
78
+ # entities: Set[str] # list of entities in document ## TODO: Add entities
79
+ metadata: dict[str, Any] | None
80
+ # NOTE: other metadata that might be useful:
81
+ # Document Creation / Last Date (e.g., recency important for legal/medical questions)
82
+ # Document Source and Trustworthiness
83
+ # Document Access Level (though this isn't important for us here.)
84
+ # Document Citations?
85
+ # Document Format? (text/spreadsheet/presentation/image/etc.)
86
+
87
+ # RAG Components
88
+ nodes: list[BaseNode]
89
+ storage_context: StorageContext # NOTE: current setup has single storage context per document.
90
+ vector_store_index: VectorStoreIndex
91
+ retriever: BaseRetriever # TODO(Jonathan Wang): Consider multiple retrievers for keywords vs semantic.
92
+ engine: BaseQueryEngine # TODO(Jonathan Wang): Consider multiple engines.
93
+ subquestion_engine: SubQuestionQueryEngine
94
+
95
+ def __init__(
96
+ self,
97
+ name: str,
98
+ file_path: Path | str,
99
+ metadata: dict[str, Any] | None = None
100
+ ) -> None:
101
+ self.id = uuid4()
102
+ self.name = name
103
+
104
+ if (isinstance(file_path, str)):
105
+ file_path = Path(file_path)
106
+ self.file_path = file_path
107
+ self.file_name = file_path.name
108
+
109
+ self.metadata = metadata
110
+
111
+
112
+ @classmethod
113
+ def class_name(cls) -> str:
114
+ return "FullDocument"
115
+
116
+ def add_name_to_nodes(self, nodes: list[GenericNode]) -> list[GenericNode]:
117
+ """Add the name of the document to the nodes.
118
+
119
+ Args:
120
+ nodes (List[GenericNode]): The nodes to add the name to.
121
+
122
+ Returns:
123
+ List[GenericNode]: The nodes with the name added.
124
+ """
125
+ for node in nodes:
126
+ node.metadata["name"] = self.name
127
+ return nodes
128
+
129
+ def file_to_nodes(
130
+ self,
131
+ reader: BaseReader,
132
+ postreaders: list[Callable[[list[GenericNode]], list[GenericNode]] | TransformComponent] | None=None, # NOTE: these should be used in order. and probably all TransformComponent instead.
133
+ node_parser: NodeParser | None=None,
134
+ postparsers: list[Callable[[list[GenericNode]], list[GenericNode]] | TransformComponent] | None=None, # Stuff like chunking, adding Embeddings, etc.
135
+ ) -> None:
136
+ """Read in the file path and get the nodes.
137
+
138
+ Args:
139
+ file_path (Optional[Path], optional): The path to the file. Defaults to file_path from init.
140
+ reader (Optional[BaseReader], optional): The reader to use. Defaults to reader from init.
141
+ """
142
+ # Use the provided reader to read in the file.
143
+ print("NEWPDF: Reading input file...")
144
+ nodes = reader.load_data(file_path=self.file_path)
145
+
146
+ # Use node postreaders to post process the nodes.
147
+ if (postreaders is not None):
148
+ for node_postreader in postreaders:
149
+ nodes = node_postreader(nodes) # type: ignore (TransformComponent allows a list of nodes)
150
+
151
+ # Use node parser to parse the nodes.
152
+ if (node_parser is None):
153
+ node_parser = Settings.node_parser
154
+ nodes = node_parser(nodes) # type: ignore (Document is a child of BaseNode)
155
+
156
+ # Use node postreaders to post process the nodes. (also add the common name to the nodes)
157
+ if (postparsers is None):
158
+ postparsers = [self.add_name_to_nodes]
159
+ else:
160
+ postparsers.append(self.add_name_to_nodes)
161
+
162
+ for node_postparser in postparsers:
163
+ nodes = node_postparser(nodes) # type: ignore (TransformComponent allows a list of nodes)
164
+
165
+ # Save nodes
166
+ self.nodes = nodes # type: ignore
167
+
168
+ def nodes_to_summary(
169
+ self,
170
+ summarizer: BaseSynthesizer, # NOTE: this is typically going to be a TreeSummarizer / SimpleSummarize for our use case
171
+ query_str: str = DEFAULT_TREE_SUMMARY_TEMPLATE,
172
+ ) -> None:
173
+ """Summarize the nodes.
174
+
175
+ Args:
176
+ summarizer (BaseSynthesizer): The summarizer to use. Takes in nodes and returns summary.
177
+ """
178
+ if (not hasattr(self, "nodes")):
179
+ msg = "Nodes must be extracted from document using `file_to_nodes` before calling `nodes_to_summary`."
180
+ raise ValueError(msg)
181
+
182
+ text_chunks = [getattr(node, "text", "") for node in self.nodes if hasattr(node, "text")]
183
+ summary_responses = summarizer.aget_response(query_str=query_str, text_chunks=text_chunks)
184
+
185
+ loop = asyncio.get_event_loop()
186
+ summary = loop.run_until_complete(summary_responses)
187
+
188
+ if (not isinstance(summary, str)):
189
+ # TODO(Jonathan Wang): ... this should always give us a string, right? we're not doing anything fancy with TokenGen/TokenAsyncGen/Pydantic BaseModel...
190
+ msg = f"Summarizer must return a string summary. Actual type: {type(summary)}, with value {summary}."
191
+ raise TypeError(msg)
192
+
193
+ self.summary = summary
194
+
195
+ def summary_to_oneline(
196
+ self,
197
+ summarizer: BaseSynthesizer, # NOTE: this is typically going to be a SimpleSummarize / TreeSummarizer for our use case
198
+ query_str: str = DEFAULT_ONELINE_SUMMARY_TEMPLATE,
199
+ ) -> None:
200
+
201
+ if (not hasattr(self, "summary")):
202
+ msg = "Summary must be extracted from document using `nodes_to_summary` before calling `summary_to_oneline`."
203
+ raise ValueError(msg)
204
+
205
+ oneline = summarizer.get_response(query_str=query_str, text_chunks=[self.summary]) # There's only one chunk.
206
+ self.summary_oneline = oneline # type: ignore | shouldn't have fancy TokenGenerators / TokenAsyncGenerators / Pydantic BaseModels
207
+
208
+ def nodes_to_document_keywords(self, keyword_extractor: Optional[KeywordMetadataAdder] = None) -> None:
209
+ """Save the keywords from the nodes into the document.
210
+
211
+ Args:
212
+ keyword_extractor (Optional[BaseKeywordExtractor], optional): The keyword extractor to use. Defaults to None.
213
+ """
214
+ if (not hasattr(self, "nodes")):
215
+ msg = "Nodes must be extracted from document using `file_to_nodes` before calling `nodes_to_keywords`."
216
+ raise ValueError(msg)
217
+
218
+ if (keyword_extractor is None):
219
+ keyword_extractor = KeywordMetadataAdder()
220
+
221
+ # Add keywords to nodes using KeywordMetadataAdder
222
+ keyword_extractor.process_nodes(self.nodes)
223
+
224
+ # Save keywords
225
+ keywords: list[str] = []
226
+ for node in self.nodes:
227
+ node_keywords = node.metadata.get("keyword_metadata", "").split(", ") # NOTE: KeywordMetadataAdder concatenates b/c required string output
228
+ keywords = keywords + node_keywords
229
+
230
+ # TODO(Jonathan Wang): handle deduplicating keywords which are similar to each other (fuzzy?)
231
+ self.keywords = set(keywords)
232
+
233
+ def nodes_to_storage(self, create_new_storage: bool = True) -> None:
234
+ """Save the nodes to storage."""
235
+ if (not hasattr(self, "nodes")):
236
+ msg = "Nodes must be extracted from document using `file_to_nodes` before calling `nodes_to_storage`."
237
+ raise ValueError(msg)
238
+
239
+ if (create_new_storage):
240
+ docstore = get_docstore(documents=self.nodes)
241
+ self.docstore = docstore
242
+
243
+ vector_store = get_vector_store()
244
+
245
+ storage_context = StorageContext.from_defaults(
246
+ docstore=docstore,
247
+ vector_store=vector_store
248
+ )
249
+ self.storage_context = storage_context
250
+
251
+ vector_store_index = VectorStoreIndex(
252
+ self.nodes, storage_context=storage_context
253
+ )
254
+ self.vector_store_index = vector_store_index
255
+
256
+ else:
257
+ ### TODO(Jonathan Wang): use an existing storage instead of creating a new one.
258
+ msg = "Currently creates new storage for every document."
259
+ raise NotImplementedError(msg)
260
+
261
+ # TODO(Jonathan Wang): Create multiple different retrievers based on the question type(?)
262
+ # E.g., if the question is focused on specific keywords or phrases, use a retriever oriented towards sparse scores.
263
+ def storage_to_retriever(
264
+ self,
265
+ semantic_nodes: int = 6,
266
+ sparse_nodes: int = 3,
267
+ fusion_nodes: int = 3,
268
+ semantic_weight: float = 0.6,
269
+ merge_up_thresh: float = 0.5,
270
+ callback_manager: CallbackManager | None=None
271
+ ) -> None:
272
+ """Create retriever from storage."""
273
+ if (not hasattr(self, "vector_store_index")):
274
+ msg = "Vector store must be extracted from document using `nodes_to_storage` before calling `storage_to_retriever`."
275
+ raise ValueError(msg)
276
+
277
+ retriever = get_retriever(
278
+ _vector_store_index=self.vector_store_index,
279
+ semantic_top_k=semantic_nodes,
280
+ sparse_top_k=sparse_nodes,
281
+ fusion_similarity_top_k=fusion_nodes,
282
+ semantic_weight_fraction=semantic_weight,
283
+ merge_up_thresh=merge_up_thresh,
284
+ verbose=True,
285
+ _callback_manager=callback_manager or ss.callback_manager
286
+ )
287
+ self.retriever = retriever
288
+
289
+ def retriever_to_engine(
290
+ self,
291
+ response_synthesizer: BaseSynthesizer,
292
+ callback_manager: CallbackManager | None=None
293
+ ) -> None:
294
+ """Create query engine from retriever."""
295
+ if (not hasattr(self, "retriever")):
296
+ msg = "Retriever must be extracted from document using `storage_to_retriever` before calling `retriever_to_engine`."
297
+ raise ValueError(msg)
298
+
299
+ engine = get_engine(
300
+ retriever=self.retriever,
301
+ response_synthesizer=response_synthesizer,
302
+ callback_manager=callback_manager or ss.callback_manager
303
+ )
304
+ self.engine = engine
305
+
306
+ # TODO(Jonathan Wang): Create Summarization Index and Engine.
307
+ def engine_to_sub_question_engine(self) -> None:
308
+ """Convert a basic query engine into a sub-question query engine for handling complex, multi-step questions.
309
+
310
+ Args:
311
+ query_engine (BaseQueryEngine): The Base Query Engine to convert.
312
+ """
313
+ if (not hasattr(self, "summary_oneline")):
314
+ msg = "One Line Summary must be created for the document before calling `engine_to_sub_query_engine`"
315
+ raise ValueError(msg)
316
+ elif (not hasattr(self, "engine")):
317
+ msg = "Basic Query Engine must be created before calling `engine_to_sub_query_engine`"
318
+ raise ValueError(msg)
319
+
320
+ sqe_tools = [
321
+ QueryEngineTool(
322
+ query_engine=self.engine, # TODO(Jonathan Wang): handle multiple engines?
323
+ metadata=ToolMetadata(
324
+ name=(self.name + " simple query answerer"),
325
+ description=f"""A tool that answers simple questions about the following document: {self.summary_oneline}"""
326
+ )
327
+ )
328
+ # TODO(Jonathan Wang): add more tools
329
+ ]
330
+
331
+ subquestion_engine = SubQuestionQueryEngine.from_defaults(
332
+ query_engine_tools=sqe_tools,
333
+ verbose=True,
334
+ use_async=True
335
+ )
336
+ self.subquestion_engine = subquestion_engine
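+
+
+ # Illustrative usage sketch of the intended call order (mirrors handle_new_pdf in app.py;
+ # the reader/parser/summarizer objects are assumed to come from the other modules):
+ # doc = FullDocument(name="input.pdf", file_path="data/input.pdf")
+ # doc.file_to_nodes(reader=pdf_reader, node_parser=node_parser)
+ # doc.nodes_to_summary(summarizer=tree_summarizer)
+ # doc.summary_to_oneline(summarizer=tree_summarizer)
+ # doc.nodes_to_document_keywords()
+ # doc.nodes_to_storage()
+ # doc.storage_to_retriever()
+ # doc.retriever_to_engine(response_synthesizer=response_synthesizer)
+ # doc.engine_to_sub_question_engine()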
keywords.py ADDED
@@ -0,0 +1,110 @@
1
+ #####################################################
2
+ ### DOCUMENT PROCESSOR [Keywords]
3
+ #####################################################
4
+ ### Jonathan Wang
5
+
6
+ # ABOUT:
7
+ # This creates an app to chat with PDFs.
8
+
9
+ # This is the Keywords
10
+ # Which creates keywords based on documents.
11
+ #####################################################
12
+ ### TODO Board:
13
+ # TODO(Jonathan Wang): Add Maximum marginal relevance to the merger for better keywords.
14
+ # TODO(Jonathan Wang): create own version of Rake keywords
15
+
16
+ #####################################################
17
+ ### PROGRAM SETTINGS
18
+
19
+
20
+ #####################################################
21
+ ### PROGRAM IMPORTS
22
+ from __future__ import annotations
23
+
24
+ from typing import Any, Callable, Optional
25
+
26
+ # Keywords
27
+ # from multi_rake import Rake # removing because of compile issues and lack of maintenance
28
+ import yake
29
+ from llama_index.core.bridge.pydantic import Field
30
+ from llama_index.core.schema import BaseNode
31
+
32
+ # Own Modules
33
+ from metadata_adder import MetadataAdder
34
+
35
+ #####################################################
36
+ ### SCRIPT
37
+
38
+ def get_keywords(input_text: str, top_k: int = 5) -> str:
39
+ """
40
+ Given a string, get its keywords using YAKE (the RAKE+YAKE distribution-based fusion is currently disabled).
41
+
42
+ Inputs:
43
+ input_text (str): the input text to get keywords from
44
+ top_k (int): the maximum number of keywords to return. Defaults to 5.
45
+
46
+ Returns:
47
+ str: A list of the keywords, joined into a string.
48
+ """
49
+ # RAKE
50
+ # kw_extractor = Rake()
51
+ # keywords_rake = kw_extractor.apply(input_text)
52
+ # keywords_rake = dict(keywords_rake)
53
+ # YAKE
54
+ kw_extractor = yake.KeywordExtractor(lan="en", dedupLim=0.9, n=3)
55
+ keywords_yake = kw_extractor.extract_keywords(input_text)
56
+ # reorder scores so that higher is better
57
+ keywords_yake = {keyword[0].lower(): (1 - keyword[1]) for keyword in keywords_yake}
58
+ keywords_yake = dict(
59
+ sorted(keywords_yake.items(), key=lambda x: x[1], reverse=True) # type hinting YAKE is miserable
60
+ )
61
+
62
+ # Merge RAKE and YAKE based on scores.
63
+ # keywords_merged = _merge_on_scores(
64
+ # list(keywords_yake.keys()),
65
+ # list(keywords_rake.keys()),
66
+ # list(keywords_yake.values()),
67
+ # list(keywords_rake.values()),
68
+ # a_weight=0.5,
69
+ # top_k=top_k
70
+ # )
71
+
72
+ # return (list(keywords_rake.keys())[:top_k], list(keywords_yake.keys())[:top_k], keywords_merged)
73
+ return ", ".join(list(keywords_yake)[:top_k]) # kinda regretting forcing this into a string
74
+
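+ # Illustrative usage sketch (assumes the top_k parameter above):
+ # get_keywords("Transformers use self-attention to weigh token interactions.", top_k=3)
+ # returns a comma-separated string of up to three YAKE keywords.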
75
+
76
+ class KeywordMetadataAdder(MetadataAdder):
77
+ """Adds keyword metadata to a document.
78
+
79
+ Args:
80
+ metadata_name: The name of the metadata to add to the document. Defaults to 'keyword_metadata'.
81
+ keywords_function: A function for keywords, given a source string and the number of keywords to get.
82
+ """
83
+
84
+ keywords_function: Callable[[str, int], str] = Field(
85
+ description="The function to use to extract keywords from the text. Input is string and number of keywords to extract. Output is string of keywords.",
86
+ default=get_keywords,
87
+ )
88
+ num_keywords: int = Field(
89
+ default=5,
90
+ description="The number of keywords to extract from the text. Defaults to 5.",
91
+ )
92
+
93
+ def __init__(
94
+ self,
95
+ metadata_name: str = "keyword_metadata",
96
+ keywords_function: Callable[[str, int], str] = get_keywords,
97
+ num_keywords: int = 5,
98
+ **kwargs: Any,
99
+ ) -> None:
100
+ """Init params."""
101
+ super().__init__(metadata_name=metadata_name, keywords_function=keywords_function, num_keywords=num_keywords, **kwargs) # ah yes i love oop :)
102
+
103
+ @classmethod
104
+ def class_name(cls) -> str:
105
+ return "KeywordMetadataAdder"
106
+
107
+ def get_node_metadata(self, node: BaseNode) -> str | None:
108
+ if not hasattr(node, "text") or node.text is None:
109
+ return None
110
+ return self.keywords_function(node.get_content(), self.num_keywords)
merger.py ADDED
@@ -0,0 +1,174 @@
1
+ #####################################################
2
+ ### DOCUMENT PROCESSOR [MERGER]
3
+ #####################################################
4
+ # Jonathan Wang
5
+
6
+ # ABOUT:
7
+ # This project creates an app to chat with PDFs.
8
+
9
+ # This is the MERGER
10
+ # which defines how two lists with scores
11
+ # should be merged together into one list.
12
+ # (Useful for fusing things like keywords or textnodes)
13
+ #####################################################
14
+ ## TODOS:
15
+ # We're looping through A/B more than necessary.
16
+
17
+ #####################################################
18
+ ## IMPORTS:
19
+ from __future__ import annotations
20
+
21
+ from typing import TYPE_CHECKING, Sequence, TypeVar, Union
22
+
23
+ import numpy as np
24
+
25
+ if TYPE_CHECKING:
26
+ from numpy.typing import NDArray
27
+
28
+ #####################################################
29
+ ## CODE:
30
+
31
+ GenericType = TypeVar("GenericType")
32
+
33
+ ### TODO(Jonathan Wang): Implement Maximum Marginal Relevance (MMR)
34
+ # https://en.wikipedia.org/wiki/Maximum_marginal_relevance
35
+ # def mmr(documents, query, scores, lambda_param=0.5):
36
+ # """
37
+ # Calculate Maximum Marginal Relevance (MMR) for a list of documents.
38
+
39
+ # Parameters:
40
+ # documents (list of np.array): List of document vectors.
41
+ # query (np.array): Query vector.
42
+ # scores (list of float): Relevance scores for each document.
43
+ # lambda_param (float): Trade-off parameter between relevance and diversity.
44
+
45
+ # Returns:
46
+ # list of int: Indices of selected documents in order of selection.
47
+ # """
48
+ # selected = []
49
+ # remaining = list(range(len(documents)))
50
+
51
+ # while remaining:
52
+ # if not selected:
53
+ # # Select the document with the highest relevance score
54
+ # idx = np.argmax(scores)
55
+ # else:
56
+ # # Calculate MMR for remaining documents
57
+ # mmr_scores = []
58
+ # for i in remaining:
59
+ # relevance = scores[i]
60
+ # diversity = max([np.dot(documents[i], documents[j]) for j in selected])
61
+ # mmr_score = lambda_param * relevance - (1 - lambda_param) * diversity
62
+ # mmr_scores.append(mmr_score)
63
+ # idx = remaining[np.argmax(mmr_scores)]
64
+
65
+ # selected.append(idx)
66
+ # remaining.remove(idx)
67
+
68
+ # return selected
69
+
70
+ def _merge_on_scores(
71
+ a_list: Sequence[GenericType],
72
+ b_list: Sequence[GenericType],
73
+ a_scores_input: Sequence[float | np.float64 | None],
74
+ b_scores_input: Sequence[float | np.float64 | None],
75
+ use_distribution: bool = True,
76
+ a_weight: float = 0.5,
77
+ top_k: int = 5,
78
+ ) -> Sequence[GenericType]:
79
+ """
80
+ Given two lists of elements with scores, fuse them together using "Distribution-Based Score Fusion".
81
+
82
+ Elements which have high scores in both lists are given even higher ranking here.
83
+
84
+ Inputs:
85
+ a_list: list of elements for A
86
+ a_scores: list of scores for each element in A. Assume higher is better. Share the same index.
87
+ b_list: list of elements for B
88
+ b_scores: list of scores for each element in B. Assume higher is better. Share the same index.
89
+ use_distribution: Whether to fuse using Min-Max Scaling (FALSE) or Distribution Based Score Fusion (TRUE)
90
+
91
+ Outputs:
92
+ List: List of elements that passed the merge.
93
+ """
94
+ # Guard Clauses
95
+ if ((len(a_list) != len(a_scores_input)) or (len(b_list) != len(b_scores_input))):
96
+ msg = (
97
+ f"""_merge_on_scores: Differing number of elements and scores!
98
+ a_list: {a_list}
99
+ a_scores: {a_scores_input}
100
+ b_list: {b_list}
101
+ b_scores: {b_scores_input}
102
+ """
103
+ )
104
+ raise ValueError(msg)
105
+
106
+ if (a_weight > 1 or a_weight < 0):
107
+ msg = "_merge_on_scores: weight for the A list should be between 0 and 1."
108
+ raise ValueError(msg)
109
+ if (top_k < 0): # or top_k > :
110
+ # TODO(Jonathan Wang): Find a nice way to get the number of unique elements in a list
111
+ # where those elements are potentially unhashable AND unorderable.
112
+ # I know about the n^2 solution with two lists and (if not in x), but it's a bit annoying.
113
+ msg = "_merge_on_scores: top_k must be between 0 and the total number of elements."
114
+ raise ValueError(msg)
115
+
116
+ # 0. Convert to numpy arrays
117
+ # NOTE: When using a SubQuestionQueryEngine, the subanswers are saved as NodesWithScores, but their score is None.
118
+ # We want to filter these out, so we get citations when the two texts are very similar.
119
+ a_scores: NDArray[np.float64] = np.array(a_scores_input, dtype=np.float64)
120
+ b_scores: NDArray[np.float64] = np.array(b_scores_input, dtype=np.float64)
121
+
122
+ # 1. Calculate mean of scores.
123
+ a_mean = np.nanmean(a_scores) # np.nan if empty
124
+ b_mean = np.nanmean(b_scores)
125
+
126
+ # 2. Calculate standard deviations
127
+ a_stdev = np.nanstd(a_scores)
128
+ b_stdev = np.nanstd(b_scores)
129
+
130
+ # 3. Get minimum and maximum bands as 3std from mean
131
+ # alternatively, use actual min-max scaling
132
+ a_min = a_mean - 3 * a_stdev if use_distribution else np.nanmin(a_scores)
133
+ a_max = a_mean + 3 * a_stdev if use_distribution else np.nanmax(a_scores)
134
+ b_min = b_mean - 3 * b_stdev if use_distribution else np.nanmin(b_scores)
135
+ b_max = b_mean + 3 * b_stdev if use_distribution else np.nanmax(b_scores)
136
+
137
+ # 4. Rescale the distributions
138
+ if (a_max > a_min):
139
+ a_scores = np.array([
140
+ ((x - a_min) / (a_max - a_min))
141
+ for x in a_scores
142
+ ], dtype=np.float64)
143
+ if (b_max > b_min):
144
+ b_scores = np.array([
145
+ (x - b_min) / (b_max - b_min)
146
+ for x in b_scores
147
+ ], dtype=np.float64)
148
+
149
+ # 5. Fuse the scores together
150
+ full_dict: list[tuple[GenericType, float]] = []
151
+ for index, element in enumerate(a_list):
152
+ a_score = a_scores[index]
153
+ if (element in b_list):
154
+ # In both A and B. Fuse score.
155
+ b_score = b_scores[b_list.index(element)]
156
+ fused_score = a_weight * a_score + (1-a_weight) * b_score
157
+ full_dict.append((element, fused_score))
158
+ else:
159
+ # Only in A.
160
+ full_dict.append((element, a_weight * a_score))
161
+
162
+ for index, element in enumerate(b_list):
163
+ if (element not in a_list):
164
+ b_score = b_scores[index]
165
+ full_dict.append((element, (1-a_weight) * b_score))
166
+
167
+ full_dict = sorted(full_dict, key=lambda item: item[1], reverse=True)
168
+ output_list = [item[0] for item in full_dict]
169
+
170
+ if (top_k >= len(full_dict)):
171
+ return output_list
172
+
173
+ # create final response object
174
+ return output_list[:top_k]
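+
+
+ # Illustrative usage sketch: fusing two scored keyword lists.
+ # _merge_on_scores(
+ #     ["alpha", "beta"], ["beta", "gamma"],
+ #     [0.9, 0.4], [0.8, 0.7],
+ #     a_weight=0.5, top_k=2,
+ # )
+ # "beta" is scored in both lists, so its fused score tends to rank it first.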
metadata_adder.py ADDED
@@ -0,0 +1,280 @@
1
+ #####################################################
2
+ ### DOCUMENT PROCESSOR [Metadata Adders]
3
+ #####################################################
4
+ ### Jonathan Wang
5
+
6
+ # ABOUT:
7
+ # This creates an app to chat with PDFs.
8
+
9
+ # This is the Metadata Adders
10
+ # Which are classes that add metadata fields to documents.
11
+ # This often is used for summaries or keywords.
12
+ #####################################################
13
+ ### TODO Board:
14
+ # Seems like this overlaps well with the `metadata extractors` interface from llama_index.
15
+ # These are TransformComponents which take a Sequence of Nodes as input and return a list of Dicts as output (with the dicts storing metadata for each node).
16
+ # We should add a wrapper which adds this metadata to nodes.
17
+ # We should also add a wrapper
18
+
19
+ # https://github.com/run-llama/llama_index/blob/be3bd619ec114d26cf328d12117c033762695b3f/llama-index-core/llama_index/core/extractors/interface.py#L21
20
+ # https://github.com/run-llama/llama_index/blob/be3bd619ec114d26cf328d12117c033762695b3f/llama-index-core/llama_index/core/extractors/metadata_extractors.py#L332
21
+
22
+ #####################################################
23
+ ### PROGRAM SETTINGS
24
+
25
+
26
+ #####################################################
27
+ ### PROGRAM IMPORTS
28
+ from __future__ import annotations
29
+
30
+ import logging
31
+ import re
32
+ from abc import abstractmethod
33
+ from typing import Any, List, Optional, TypeVar, Sequence
34
+
35
+ from llama_index.core.bridge.pydantic import Field, PrivateAttr
36
+ from llama_index.core.schema import BaseNode, TransformComponent
37
+
38
+ # Own modules
39
+
40
+
41
+ #####################################################
42
+ ### CONSTANTS
43
+ # ah how beautiful the regex
44
+ # handy visualizer and checker: https://www.debuggex.com/, https://www.regexpr.com/
45
+ logger = logging.getLogger(__name__)
46
+ GenericNode = TypeVar("GenericNode", bound=BaseNode)
47
+
48
+ DATE_REGEX = re.compile(r"(?:(?<!\:)(?<!\:\d)[0-3]?\d(?:st|nd|rd|th)?\s+(?:of\s+)?(?:jan\.?|january|feb\.?|february|mar\.?|march|apr\.?|april|may|jun\.?|june|jul\.?|july|aug\.?|august|sep\.?|september|oct\.?|october|nov\.?|november|dec\.?|december)|(?:jan\.?|january|feb\.?|february|mar\.?|march|apr\.?|april|may|jun\.?|june|jul\.?|july|aug\.?|august|sep\.?|september|oct\.?|october|nov\.?|november|dec\.?|december)\s+(?<!\:)(?<!\:\d)[0-3]?\d(?:st|nd|rd|th)?)(?:\,)?\s*(?:\d{4})?|[0-3]?\d[-\./][0-3]?\d[-\./]\d{2,4}", re.IGNORECASE)
49
+ TIME_REGEX = re.compile(r"\d{1,2}:\d{2} ?(?:[ap]\.?m\.?)?|\d[ap]\.?m\.?", re.IGNORECASE)
50
+ EMAIL_REGEX = re.compile(r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)")
51
+ PHONE_REGEX = re.compile(r"((?:(?<![\d-])(?:\+?\d{1,3}[-.\s*]?)?(?:\(?\d{3}\)?[-.\s*]?)?\d{3}[-.\s*]?\d{4}(?![\d-]))|(?:(?<![\d-])(?:(?:\(\+?\d{2}\))|(?:\+?\d{2}))\s*\d{2}\s*\d{3}\s*\d{4}(?![\d-])))")
52
+ MAIL_ADDR_REGEX = re.compile(r"\d{1,4}.{1,10}[\w\s]{1,20}[\s]+(?:street|st|avenue|ave|road|rd|highway|hwy|square|sq|trail|trl|drive|dr|court|ct|parkway|pkwy|circle|cir|boulevard|blvd)\W?(?=\s|$)", re.IGNORECASE)
53
+
54
+ # DEFAULT_NUM_WORKERS = os.cpu_count() - 1 if os.cpu_count() else 1 # type: ignore
55
+
56
+
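+ # Hedged sanity-check sketch for the regexes above (the sample strings are made up, not from any document):
+ # >>> bool(DATE_REGEX.search("The report is due March 3rd, 2024."))   # True
+ # >>> bool(TIME_REGEX.search("Meet at 10:30 am."))                    # True
+ # >>> bool(EMAIL_REGEX.search("Contact jane.doe@example.com."))       # True
+ # >>> bool(PHONE_REGEX.search("Call (555) 123-4567."))                # True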
57
+ #####################################################
58
+ ### SCRIPT
59
+
60
+ class MetadataAdder(TransformComponent):
61
+ """Adds metadata to a node.
62
+
63
+ Args:
64
+ metadata_name: The name of the metadata to add to the node. Defaults to 'metadata'.
65
+ # num_workers: The number of workers to use for parallel processing. By default, use all available cores minus one. currently WIP.
66
+ """
67
+
68
+ metadata_name: str = Field(
69
+ default="metadata",
70
+ description="The name of the metadata field to add to the document. Defaults to 'metadata'.",
71
+ )
72
+ # num_workers: int = Field(
73
+ # default=DEFAULT_NUM_WORKERS,
74
+ # description="The number of workers to use for parallel processing. By default, use all available cores minus one.",
75
+ # )
76
+
77
+ def __init__(
78
+ self, metadata_name: str = "metadata", **kwargs: Any
79
+ ) -> None:
80
+ super().__init__(**kwargs)
81
+ self.metadata_name = metadata_name
82
+ # self.num_workers = num_workers
83
+
84
+ @classmethod
85
+ def class_name(cls) -> str:
86
+ return "MetadataAdder"
87
+
88
+ @abstractmethod
89
+ def get_node_metadata(self, node: BaseNode) -> str | None:
90
+ """Given a node, get the metadata for the node."""
91
+
92
+ def add_node_metadata(self, node: GenericNode, metadata_value: Any | None) -> GenericNode:
93
+ """Given a node and the metadata, add the metadata to the node's `metadata_name` field."""
94
+ if (metadata_value is None):
95
+ return node
96
+ else:
97
+ node.metadata[self.metadata_name] = metadata_value
98
+ return node
99
+
100
+ def process_nodes(self, nodes: list[GenericNode]) -> list[GenericNode]:
101
+ """Process the list of nodes. This gets called by __call__.
102
+
103
+ Args:
104
+ nodes (List[GenericNode]): The nodes to process.
105
+
106
+ Returns:
107
+ List[GenericNode]: The processed nodes, with metadata field metadata_name added.
108
+ """
109
+ output_nodes = []
110
+ for node in nodes:
111
+ node_metadata = self.get_node_metadata(node)
112
+ node_with_metadata = self.add_node_metadata(node, node_metadata)
113
+ output_nodes.append(node_with_metadata)
114
+ return(output_nodes)
115
+
116
+ def __call__(self, nodes: Sequence[BaseNode], **kwargs: Any) -> list[BaseNode]:
117
+ """Check whether nodes have the specified regex pattern."""
118
+ return self.process_nodes(nodes)
119
+
120
+
121
+ class RegexMetadataAdder(MetadataAdder):
122
+ """Adds regex metadata to a document.
123
+
124
+ Args:
125
+ regex_pattern: The regex pattern to search for.
126
+ metadata_name: The name of the metadata to add to the document. Defaults to 'regex_metadata'.
127
+ # num_workers: The number of workers to use for parallel processing. By default, use all available cores minus one.
128
+ """
129
+
130
+ _regex_pattern: re.Pattern = PrivateAttr()
131
+ _boolean_mode: bool = PrivateAttr()
132
+ # num_workers: int = Field(
133
+ # default=DEFAULT_NUM_WORKERS,
134
+ # description="The number of workers to use for parallel processing. By default, use all available cores minus one.",
135
+ # )
136
+
137
+ def __init__(
138
+ self,
139
+ regex_pattern: re.Pattern | str = DATE_REGEX,
140
+ metadata_name: str = "regex_metadata",
141
+ boolean_mode: bool = False,
142
+ # num_workers: int = DEFAULT_NUM_WORKERS,
143
+ **kwargs: Any,
144
+ ) -> None:
145
+ """Init params."""
146
+ if (isinstance(regex_pattern, str)):
147
+ regex_pattern = re.compile(regex_pattern)
148
+ # self.num_workers = num_workers
149
+ super().__init__(metadata_name=metadata_name, **kwargs) # ah yes i love oop :)
150
+ self._regex_pattern=regex_pattern
151
+ self._boolean_mode=boolean_mode
152
+
153
+ @classmethod
154
+ def class_name(cls) -> str:
155
+ return "RegexMetadataAdder"
156
+
157
+ def get_node_metadata(self, node: BaseNode) -> str | None:
158
+ """Given a node with text, return the regex match if it exists.
159
+
160
+ Args:
161
+ node (BaseNode): The base node to extract from.
162
+
163
+ Returns:
164
+ Optional[str]: The regex match if it exists. If not, return None.
165
+ """
166
+ if (getattr(node, "text", None) is None):
167
+ return None
168
+
169
+ if (self._boolean_mode):
170
+ return str(self._regex_pattern.search(node.text) is not None)  # search rather than match: the pattern may occur anywhere in the text
171
+ else:
172
+ return str(self._regex_pattern.findall(node.text)) # NOTE: we are saving these as a string'd list since this is easier
173
+
174
+
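+ # A minimal usage sketch for RegexMetadataAdder (in-memory node; not part of the pipeline wiring above):
+ #
+ # from llama_index.core.schema import TextNode
+ #
+ # date_adder = RegexMetadataAdder(regex_pattern=DATE_REGEX, metadata_name="dates")
+ # nodes = date_adder([TextNode(text="The inspection took place on March 3rd, 2024.")])
+ # print(nodes[0].metadata["dates"])  # string'd list of matches, e.g. "['March 3rd, 2024']"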
175
+ class ModelMetadataAdder(MetadataAdder):
176
+ """Adds metadata to nodes based on a language model."""
177
+
178
+ prompt_template: str = Field(
179
+ description="The prompt to use to generate the metadata. Defaults to DEFAULT_SUMMARY_TEMPLATE.",
180
+ )
181
+
182
+ def __init__(
183
+ self,
184
+ metadata_name: str,
185
+ prompt_template: str | None = None,
186
+ **kwargs: Any
187
+ ) -> None:
188
+ """Init params."""
189
+ super().__init__(metadata_name=metadata_name, prompt_template=prompt_template, **kwargs)
190
+
191
+ @classmethod
192
+ def class_name(cls) -> str:
193
+ return "ModelMetadataAdder"
194
+
195
+ @abstractmethod
196
+ def get_node_metadata(self, node: BaseNode) -> str | None:
197
+ """Given a node, get the metadata for the node.
198
+
199
+ Args:
200
+ node (BaseNode): The node to add metadata to.
201
+
202
+ Returns:
203
+ Optional[str]: The metadata if it exists. If not, return None.
204
+ """
205
+
206
+
207
+ class UnstructuredPDFPostProcessor(TransformComponent):
208
+ """Handles postprocessing of PDF which was read in using UnstructuredIO."""
209
+
210
+ ### NOTE: okay technically we could have done this in the IngestionPipeline abstraction. Maybe we integrate in the future?
211
+ # This component doesn't play nice with multi-processing due to having non-async LLMs.
212
+
213
+ # _embed_model: Optional[BaseEmbedding] = PrivateAttr()
214
+ _metadata_adders: list[MetadataAdder] = PrivateAttr()
215
+
216
+ def __init__(
217
+ self,
218
+ # embed_model: Optional[BaseEmbedding] = None,
219
+ metadata_adders: list[MetadataAdder] | None = None,
220
+ **kwargs: Any,
221
+ ) -> None:
222
+ super().__init__(**kwargs)
223
+ # self._embed_model = embed_model or Settings.embed_model
224
+ self._metadata_adders = metadata_adders or []
225
+
226
+ @classmethod
227
+ def class_name(cls) -> str:
228
+ return "UnstructuredPDFPostProcessor"
229
+
230
+ # def _apply_embed_model(self, nodes: List[BaseNode]) -> List[BaseNode]:
231
+ # if (self._embed_model is not None):
232
+ # nodes = self._embed_model(nodes)
233
+ # return nodes
234
+
235
+ def _apply_metadata_adders(self, nodes: list[GenericNode]) -> list[GenericNode]:
236
+ for metadata_adder in self._metadata_adders:
237
+ nodes = metadata_adder(nodes)
238
+ return nodes
239
+
240
+ def __call__(self, nodes: list[GenericNode], **kwargs: Any) -> Sequence[BaseNode]:
241
+ return self._apply_metadata_adders(nodes)
242
+ # nodes = self._apply_embed_model(nodes) # this goes second in case we want to embed the metadata.
243
+
244
+ # def has_email(input_text: str) -> bool:
245
+ # """
246
+ # Given a chunk of text, determine whether it has an email address or not.
247
+
248
+ # We're using the long complex email regex from https://emailregex.com/index.html
249
+ # """
250
+ # return (EMAIL_REGEX.search(input_text) is not None)
251
+
252
+
253
+ # def has_phone(input_text: str) -> bool:
254
+ # """
255
+ # Given a chunk of text, determine whether it has a phone number or not.
256
+ # """
257
+ # has_phone = PHONE_REGEX.search(input_text)
258
+ # return (has_phone is not None)
259
+
260
+
261
+ # def has_mail_addr(input_text: str) -> bool:
262
+ # """
263
+ # Given a chunk of text, determine whether it has a mailing address or not.
264
+
265
+ # NOTE: This is difficult to do with regex.
266
+ # ... We could use spacy's English language NER model instead / as well:
267
+ # Assume that addresses will have a GSP (geospatial political) or GPE (geopolitical entity).
268
+ # DOCS SEE: https://www.nltk.org/book/ch07.html | https://spacy.io/usage/linguistic-features
269
+ # """
270
+ # has_addr = MAIL_ADDR_REGEX.search(input_text)
271
+ # return (has_addr is not None)
272
+
273
+
274
+ # def has_date(input_text: str) -> bool:
275
+ # """
276
+ # Given a chunk of text, determine whether it has a date or not.
277
+ # NOTE: relative dates are stuff like "within 30 days"
278
+ # """
279
+ # has_date = DATE_REGEX.search(input_text)
280
+ # return (has_date is not None)
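+ # A rough sketch of how these adders are meant to be combined downstream (illustrative only;
+ # the real wiring happens in the PDF reader / pipeline modules):
+ #
+ # postprocessor = UnstructuredPDFPostProcessor(
+ #     metadata_adders=[
+ #         RegexMetadataAdder(regex_pattern=EMAIL_REGEX, metadata_name="emails"),
+ #         RegexMetadataAdder(regex_pattern=PHONE_REGEX, metadata_name="phones"),
+ #     ]
+ # )
+ # nodes = postprocessor(nodes)  # each node gains 'emails' and 'phones' metadata fields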
models.py ADDED
@@ -0,0 +1,785 @@
1
+ #####################################################
2
+ ### DOCUMENT PROCESSOR [MODELS]
3
+ #####################################################
4
+ # Jonathan Wang
5
+
6
+ # ABOUT:
7
+ # This project creates an app to chat with PDFs.
8
+
9
+ # This is the LANGUAGE MODELS
10
+ # that are used in the document reader.
11
+ #####################################################
12
+ ## TODOS:
13
+ # <!> Add support for vLLM / AWQ / GPTQ models. (probably not going to be done due to lack of attention scores)
14
+
15
+ # Add KTransformers backend?
16
+ # https://github.com/kvcache-ai/ktransformers
17
+
18
+ # https://github.com/Tada-AI/pdf_parser
19
+
20
+ #####################################################
21
+ ## IMPORTS:
22
+ from __future__ import annotations
23
+
24
+ import gc
25
+ import logging
26
+ import sys
27
+ from typing import (
28
+ Any,
29
+ Callable,
30
+ Dict,
31
+ List,
32
+ Optional,
33
+ Protocol,
34
+ Sequence,
35
+ Union,
36
+ cast,
37
+ runtime_checkable,
38
+ )
39
+
40
+ import streamlit as st
41
+ import torch
42
+ from llama_index.core.base.embeddings.base import BaseEmbedding
43
+ from llama_index.core.base.llms.base import BaseLLM
44
+ from llama_index.core.base.llms.generic_utils import (
45
+ messages_to_prompt as generic_messages_to_prompt,
46
+ )
47
+ from llama_index.core.base.llms.types import (
48
+ ChatMessage,
49
+ ChatResponse,
50
+ ChatResponseGen,
51
+ CompletionResponse,
52
+ CompletionResponseGen,
53
+ LLMMetadata,
54
+ MessageRole,
55
+ )
56
+ from llama_index.core.bridge.pydantic import Field, PrivateAttr
57
+ from llama_index.core.callbacks import CallbackManager
58
+ from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
59
+ from llama_index.core.llms.callbacks import (
60
+ llm_chat_callback,
61
+ llm_completion_callback,
62
+ )
63
+ from llama_index.core.multi_modal_llms import MultiModalLLM
64
+ from llama_index.core.postprocessor import SentenceTransformerRerank
65
+ from llama_index.core.prompts.base import PromptTemplate
66
+ from llama_index.core.schema import ImageDocument, ImageNode
67
+ from llama_index.core.types import BaseOutputParser, PydanticProgramMode
68
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
69
+ from llama_index.llms.huggingface import HuggingFaceLLM
70
+ from PIL import Image as PILImage
71
+ from transformers import (
72
+ AutoImageProcessor,
73
+ AutoModelForVision2Seq,
74
+ AutoTokenizer,
75
+ LogitsProcessor,
76
+ QuantoConfig,
77
+ StoppingCriteria,
78
+ StoppingCriteriaList,
79
+ )
80
+ from pydantic import WithJsonSchema  # NOTE: used by the Annotated alias below; assumes the pydantic v2 API
+ from typing_extensions import Annotated
81
+
82
+ # from wtpsplit import SaT # Sentence segmentation model. Dropping this. Requires adapters=0.2.1->Transformers=4.39.3 | Phi3 Vision requires Transformers 4.40.2
83
+
84
+ ## NOTE: Proposal for LAZY LOADING packages for running LLMS:
85
+ # Currently not done because emphasis is on local inference w/ ability to get Attention Scores, which is not yet supported in non-HF Transformers methods.
86
+
87
+ ## LLamacpp:
88
+ # from llama_index.llms.llama_cpp import LlamaCPP
89
+ # from llama_index.llms.llama_cpp.llama_utils import (
90
+ # messages_to_prompt,
91
+ # completion_to_prompt
92
+ # )
93
+
94
+ ## HF Transformers LLM:
95
+ # from transformers import AutoTokenizer, BitsAndBytesConfig
96
+ # from llama_index.llms.huggingface import HuggingFaceLLM
97
+
98
+ ## GROQ
99
+ # from llama_index.llms.groq import Groq
100
+
101
+ #####################################################
102
+ ### SETTINGS:
103
+ DEFAULT_HF_MULTIMODAL_LLM = "Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5"
104
+ DEFAULT_HF_MULTIMODAL_CONTEXT_WINDOW = 1024
105
+ DEFAULT_HF_MULTIMODAL_MAX_NEW_TOKENS = 1024
106
+
107
+ #####################################################
108
+ ### CODE:
109
+ logger = logging.getLogger(__name__)
110
+
111
+ @st.cache_resource
112
+ def get_embedder(
113
+ model_path: str = "mixedbread-ai/mxbai-embed-large-v1",
114
+ device: str = "cuda", # 'cpu' is unbearably slow
115
+ ) -> BaseEmbedding:
116
+ """Given the path to an embedding model, load it."""
117
+ # NOTE: okay we definitely could have not made this wrapper, but shrug
118
+ return HuggingFaceEmbedding(
119
+ model_path,
120
+ device=device
121
+ )
122
+
123
+
124
+ @st.cache_resource
125
+ def get_reranker(
126
+ model_path: str = "mixedbread-ai/mxbai-rerank-large-v1",
127
+ top_n: int = 3,
128
+ device: str = "cpu", # 'cuda' if we were rich
129
+ ) -> SentenceTransformerRerank: # technically this is a BaseNodePostprocessor, but that seems too abstract.
130
+ """Given the path to a reranking model, load it."""
131
+ # NOTE: okay we definitely could have not made this wrapper, but shrug
132
+ return SentenceTransformerRerank(
133
+ model=model_path,
134
+ top_n=top_n,
135
+ device=device
136
+ )
137
+
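+ # Minimal usage sketch (defaults above; st.cache_resource means each model is only loaded once per session):
+ #
+ # embedder = get_embedder()
+ # reranker = get_reranker(top_n=5)
+ # query_embedding = embedder.get_query_embedding("What is the total amount due?")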
138
+
139
+ ## LLM Options Below
140
+ # def _get_llamacpp_llm(
141
+ # model_path: str,
142
+ # model_seed: int = 31415926,
143
+ # model_temperature: float = 1e-64, # ideally 0, but HF-type doesn't allow that. # a good dev might use sys.float_info()['min']
144
+ # model_context_length: Optional[int] = 8192,
145
+ # model_max_new_tokens: Optional[int] = 1024,
146
+ # ) -> BaseLLM:
147
+ # """Load a LlamaCPP model using GPU and other sane defaults."""
148
+ # # Lazy Loading
149
+ # from llama_index.llms.llama_cpp import LlamaCPP
150
+ # from llama_index.llms.llama_cpp.llama_utils import (
151
+ # messages_to_prompt,
152
+ # completion_to_prompt
153
+ # )
154
+
155
+ # # Arguments to Pass
156
+ # llm = LlamaCPP(
157
+ # model_path=model_path,
158
+ # temperature=model_temperature,
159
+ # max_new_tokens=model_max_new_tokens,
160
+ # context_window=model_context_length,
161
+ # # kwargs to pass to __call__()
162
+ # generate_kwargs={'seed': model_seed}, # {'temperature': TEMPERATURE, 'top_p':0.7, 'min_p':0.1, 'seed': MODEL_SEED},
163
+ # # kwargs to pass to __init__()
164
+ # # set to at least 1 to use GPU
165
+ # model_kwargs={'n_gpu_layers': -1, 'n_threads': os.cpu_count()-1}, #, 'rope_freq_scale': 0.83, 'rope_freq_base': 20000},
166
+ # # transform inputs into model format
167
+ # messages_to_prompt=messages_to_prompt,
168
+ # completion_to_prompt=completion_to_prompt,
169
+ # verbose=True,
170
+ # )
171
+ # return (llm)
172
+
173
+
174
+ @st.cache_resource
175
+ def _get_hf_llm(
176
+ model_path: str,
177
+ model_temperature: float = sys.float_info.min, # ideally 0, but the HF wrapper rejects exactly 0, so use the smallest positive float
178
+ model_context_length: int | None = 16384,
179
+ model_max_new_tokens: int | None = 2048,
180
+ hf_quant_level: int | None = 8,
181
+ ) -> BaseLLM:
182
+ """Load a Huggingface-Transformers based model using sane defaults."""
183
+ # Fix temperature if needed; HF implementation complains about it being zero
184
+ model_temperature = max(sys.float_info.min, model_temperature)
185
+
186
+ # Get quantization config with Quanto
187
+ quanto_config = None # NOTE: by default, no quantization.
188
+ if (hf_quant_level == 4):
189
+ # bnb_config = BitsAndBytesConfig(
190
+ # # load_in_8bit=True,
191
+ # load_in_4bit=True,
192
+ # # bnb_4bit_use_double_quant=True,
193
+ # bnb_4bit_quant_type="nf4",
194
+ # bnb_4bit_compute_dtype='bfloat16', # NOTE: Tesla T4 GPUs are too crappy for bfloat16
195
+ # # bnb_4bit_compute_dtype='float16'
196
+ # )
197
+ quanto_config = QuantoConfig(
198
+ weights="int4" # there's also 'int2' if you're crazy...
199
+ )
200
+ elif (hf_quant_level == 8):
201
+ # bnb_config = BitsAndBytesConfig(
202
+ # load_in_8bit=True
203
+ # )
204
+ quanto_config = QuantoConfig(
205
+ weights="int8"
206
+ )
207
+
208
+ # Get Stopping Tokens for Llama3 based models, because they're /special/ and added a new one.
209
+ tokenizer = AutoTokenizer.from_pretrained(
210
+ model_path
211
+ )
212
+ stopping_ids = [
213
+ tokenizer.eos_token_id,
214
+ tokenizer.convert_tokens_to_ids("<|eot_id|>"),
215
+ ]
216
+ return HuggingFaceLLM(
217
+ model_name=model_path,
218
+ tokenizer_name=model_path,
219
+ stopping_ids=stopping_ids,
220
+ max_new_tokens=model_max_new_tokens or DEFAULT_NUM_OUTPUTS,
221
+ context_window=model_context_length or DEFAULT_CONTEXT_WINDOW,
222
+ tokenizer_kwargs={"trust_remote_code": True},
223
+ model_kwargs={"trust_remote_code": True, "quantization_config": quanto_config},
224
+ generate_kwargs={
225
+ "do_sample": not model_temperature > sys.float_info.min,
226
+ "temperature": model_temperature,
227
+ },
228
+ is_chat_model=True,
229
+ )
230
+
231
+
232
+ @st.cache_resource
233
+ def get_llm(
234
+ model_path: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
235
+ model_temperature: float = 0, # ideally 0; clamped up to the smallest positive float inside _get_hf_llm
236
+ model_context_length: int | None = 8192,
237
+ model_max_new_tokens: int | None = 1024,
238
+
239
+ hf_quant_level: int | None = 8, # 4-bit / 8-bit loading for HF models
240
+ ) -> BaseLLM:
241
+ """
242
+ Given the path to an LLM, determine the type, load it in, and convert it into a LlamaIndex-compatible LLM.
243
+
244
+ NOTE: I chose to set some "sane" defaults, so it's probably not as flexible as some other dev would like.
245
+ """
246
+ # if (model_path_extension == ".gguf"):
247
+ # ##### LLAMA.CPP
248
+ # return(_get_llamacpp_llm(model_path, model_seed, model_temperature, model_context_length, model_max_new_tokens))
249
+
250
+ # TODO(Jonathan Wang): Consider non-HF-Transformers backends
251
+ # vLLM support for AWQ/GPTQ models
252
+ # I guess reluctantly AutoAWQ and AutoGPTQ packages.
253
+ # Exllamav2 is kinda dead IMO.
254
+
255
+ # else:
256
+ #### No extension or weird fake extension suggests a folder, i.e., the base model from HF
257
+ return(_get_hf_llm(model_path=model_path, model_temperature=model_temperature, model_context_length=model_context_length, model_max_new_tokens=model_max_new_tokens, hf_quant_level=hf_quant_level))
258
+
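+ # Minimal usage sketch (assumes the default Llama 3.1 8B Instruct checkpoint is accessible on this machine):
+ #
+ # llm = get_llm(hf_quant_level=8)
+ # response = llm.complete("Summarize the attached lease agreement in one sentence.")
+ # print(response.text)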
259
+
260
+ # @st.cache_resource
261
+ # def get_llm() -> BaseLLM:
262
+ # from llama_index.llms.groq import Groq
263
+
264
+ # llm = Groq(
265
+ # model='llama-3.1-8b-instant', # old: 'llama3-8b-8192'
266
+ # api_key=os.environ.get('GROQ_API_KEY'),
267
+ # )
268
+ # return (llm)
269
+
270
+
271
+ class EosLogitProcessor(LogitsProcessor):
272
+ """Special snowflake processor for Salesforce Vision Model."""
273
+ def __init__(self, eos_token_id: int, end_token_id: int):
274
+ super().__init__()
275
+ self.eos_token_id = eos_token_id
276
+ self.end_token_id = end_token_id
277
+
278
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
279
+ if input_ids.size(1) > 1: # Expect at least 1 output token.
280
+ forced_eos = torch.full((scores.size(1),), -float("inf"), device=input_ids.device)
281
+ forced_eos[self.eos_token_id] = 0
282
+
283
+ # Force generation of EOS after the <|end|> token.
284
+ scores[input_ids[:, -1] == self.end_token_id] = forced_eos
285
+ return scores
286
+
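+ # Sketch of how this processor is wired into generation further below (32007 is assumed to be
+ # the <|end|> token id for the Salesforce/Phi-3 checkpoint):
+ #
+ # tokens = model.generate(
+ #     **inputs,
+ #     logits_processor=[EosLogitProcessor(eos_token_id=tokenizer.eos_token_id, end_token_id=32007)],
+ # )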
287
+ # NOTE: These two protocols are needed to appease mypy
288
+ # https://github.com/run-llama/llama_index/blob/5238b04c183119b3035b84e2663db115e63dcfda/llama-index-core/llama_index/core/llms/llm.py#L89
289
+ @runtime_checkable
290
+ class MessagesImagesToPromptType(Protocol):
291
+ def __call__(self, messages: Sequence[ChatMessage], images: Sequence[ImageDocument], **kwargs: Any) -> str:
292
+ pass
293
+
294
+ MessagesImagesToPromptCallable = Annotated[
295
+ Optional[MessagesImagesToPromptType],
296
+ WithJsonSchema({"type": "string"}),
297
+ ]
298
+
299
+
300
+ # https://huggingface.co/Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5/blob/main/batch_inference.ipynb
301
+
302
+ class HuggingFaceMultiModalLLM(MultiModalLLM):
303
+ """Supposed to be a wrapper around HuggingFace's Vision LLMS.
304
+ Currently only supports one model type: Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5
305
+ """
306
+
307
+ model_name: str = Field(
308
+ description='The multi-modal huggingface LLM to use. Currently only using Phi3.',
309
+ default=DEFAULT_HF_MULTIMODAL_LLM
310
+ )
311
+ context_window: int = Field(
312
+ default=DEFAULT_HF_MULTIMODAL_CONTEXT_WINDOW,
313
+ description="The maximum number of tokens available for input.",
314
+ gt=0,
315
+ )
316
+ max_new_tokens: int = Field(
317
+ default=DEFAULT_HF_MULTIMODAL_MAX_NEW_TOKENS,
318
+ description="The maximum number of tokens to generate.",
319
+ gt=0,
320
+ )
321
+ system_prompt: str = Field(
322
+ default="",
323
+ description=(
324
+ "The system prompt, containing any extra instructions or context. "
325
+ "The model card on HuggingFace should specify if this is needed."
326
+ ),
327
+ )
328
+ query_wrapper_prompt: PromptTemplate = Field(
329
+ default=PromptTemplate("{query_str}"),
330
+ description=(
331
+ "The query wrapper prompt, containing the query placeholder. "
332
+ "The model card on HuggingFace should specify if this is needed. "
333
+ "Should contain a `{query_str}` placeholder."
334
+ ),
335
+ )
336
+ tokenizer_name: str = Field(
337
+ default=DEFAULT_HF_MULTIMODAL_LLM,
338
+ description=(
339
+ "The name of the tokenizer to use from HuggingFace. "
340
+ "Unused if `tokenizer` is passed in directly."
341
+ ),
342
+ )
343
+ processor_name: str = Field(
344
+ default=DEFAULT_HF_MULTIMODAL_LLM,
345
+ description=(
346
+ "The name of the processor to use from HuggingFace. "
347
+ "Unused if `processor` is passed in directly."
348
+ ),
349
+ )
350
+ device_map: str = Field(
351
+ default="auto", description="The device_map to use. Defaults to 'auto'."
352
+ )
353
+ stopping_ids: list[int] = Field(
354
+ default_factory=list,
355
+ description=(
356
+ "The stopping ids to use. "
357
+ "Generation stops when these token IDs are predicted."
358
+ ),
359
+ )
360
+ tokenizer_outputs_to_remove: list = Field(
361
+ default_factory=list,
362
+ description=(
363
+ "The outputs to remove from the tokenizer. "
364
+ "Sometimes huggingface tokenizers return extra inputs that cause errors."
365
+ ),
366
+ )
367
+ tokenizer_kwargs: dict = Field(
368
+ default_factory=dict, description="The kwargs to pass to the tokenizer."
369
+ )
370
+ processor_kwargs: dict = Field(
371
+ default_factory=dict, description="The kwargs to pass to the processor."
372
+ )
373
+ model_kwargs: dict = Field(
374
+ default_factory=dict,
375
+ description="The kwargs to pass to the model during initialization.",
376
+ )
377
+ generate_kwargs: dict = Field(
378
+ default_factory=dict,
379
+ description="The kwargs to pass to the model during generation.",
380
+ )
381
+ is_chat_model: bool = Field(
382
+ default=False,
383
+ description=(
384
+ "Whether the model can have multiple messages passed at once, like the OpenAI chat API."
385
+ # LLMMetadata.__fields__["is_chat_model"].field_info.description
386
+ # + " Be sure to verify that you either pass an appropriate tokenizer "
387
+ # "that can convert prompts to properly formatted chat messages or a "
388
+ # "`messages_to_prompt` that does so."
389
+ ),
390
+ )
391
+ messages_images_to_prompt: MessagesImagesToPromptCallable = Field(
392
+ default=generic_messages_to_prompt,
393
+ description="A function that takes in a list of messages and images and returns a prompt string.",
394
+ )
395
+
396
+ _model: Any = PrivateAttr()
397
+ _tokenizer: Any = PrivateAttr()
398
+ # TODO(Jonathan Wang): We need to add a separate field for AutoProcessor as opposed to ImageProcessors.
399
+ _processor: Any = PrivateAttr()
400
+ _stopping_criteria: Any = PrivateAttr()
401
+
402
+ def __init__(
403
+ self,
404
+ context_window: int = DEFAULT_HF_MULTIMODAL_CONTEXT_WINDOW,
405
+ max_new_tokens: int = DEFAULT_HF_MULTIMODAL_MAX_NEW_TOKENS,
406
+ query_wrapper_prompt: Union[str, PromptTemplate] = "{query_str}",
407
+ tokenizer_name: str = DEFAULT_HF_MULTIMODAL_LLM,
408
+ processor_name: str = DEFAULT_HF_MULTIMODAL_LLM,
409
+ model_name: str = DEFAULT_HF_MULTIMODAL_LLM,
410
+ model: Any | None = None,
411
+ tokenizer: Any | None = None,
412
+ processor: Any | None = None,
413
+ device_map: str = "auto",
414
+ stopping_ids: list[int] | None = None,
415
+ tokenizer_kwargs: dict[str, Any] | None = None,
416
+ processor_kwargs: dict[str, Any] | None = None,
417
+ tokenizer_outputs_to_remove: list[str] | None = None,
418
+ model_kwargs: dict[str, Any] | None = None,
419
+ generate_kwargs: dict[str, Any] | None = None,
420
+ is_chat_model: bool = False,
421
+ callback_manager: CallbackManager | None = None,
422
+ system_prompt: str = "",
423
+ messages_images_to_prompt: Callable[[Sequence[ChatMessage], Sequence[ImageDocument]], str] | None = None,
424
+ # completion_to_prompt: Callable[[str], str] | None = None,
425
+ # pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
426
+ # output_parser: BaseOutputParser | None = None,
427
+ ) -> None:
428
+
429
+ logger.info(f"CUDA Memory Pre-AutoModelForVision2Seq: {torch.cuda.mem_get_info()}")
430
+ # Salesforce one is a AutoModelForVision2Seq, but not AutoCausalLM which is more common.
431
+ model = model or AutoModelForVision2Seq.from_pretrained(
432
+ model_name,
433
+ device_map=device_map,
434
+ trust_remote_code=True,
435
+ **(model_kwargs or {})
436
+ )
437
+ logger.info(f"CUDA Memory Post-AutoModelForVision2Seq: {torch.cuda.mem_get_info()}")
438
+
439
+ # check context_window
440
+ config_dict = model.config.to_dict()
441
+ model_context_window = int(
442
+ config_dict.get("max_position_embeddings", context_window)
443
+ )
444
+ if model_context_window < context_window:
445
+ logger.warning(
446
+ f"Supplied context_window {context_window} is greater "
447
+ f"than the model's max input size {model_context_window}. "
448
+ "Disable this warning by setting a lower context_window."
449
+ )
450
+ context_window = model_context_window
451
+
452
+ processor_kwargs = processor_kwargs or {}
453
+ if "max_length" not in processor_kwargs:
454
+ processor_kwargs["max_length"] = context_window
455
+
456
+ # NOTE: Sometimes models (phi-3) will use AutoProcessor and include the tokenizer within it.
457
+ logger.info(f"CUDA Memory Pre-Processor: {torch.cuda.mem_get_info()}")
458
+ processor = processor or AutoImageProcessor.from_pretrained(
459
+ processor_name or model_name,
460
+ trust_remote_code=True,
461
+ **processor_kwargs
462
+ )
463
+ logger.info(f"CUDA Memory Post-Processor: {torch.cuda.mem_get_info()}")
464
+
465
+ tokenizer = tokenizer or AutoTokenizer.from_pretrained(
466
+ tokenizer_name or model_name,
467
+ trust_remote_code=True,
468
+ **(tokenizer_kwargs or {})
469
+ )
470
+ logger.info(f"CUDA Memory Post-Tokenizer: {torch.cuda.mem_get_info()}")
471
+
472
+ # Tokenizer-Model disagreement
473
+ if (hasattr(tokenizer, "name_or_path") and tokenizer.name_or_path != model_name): # type: ignore (checked for attribute)
474
+ logger.warning(
475
+ f"The model `{model_name}` and processor `{getattr(tokenizer, 'name_or_path', None)}` "
476
+ f"are different, please ensure that they are compatible."
477
+ )
478
+ # Processor-Model disagreement
479
+ if (hasattr(processor, "name_or_path") and getattr(processor, "name_or_path", None) != model_name):
480
+ logger.warning(
481
+ f"The model `{model_name}` and processor `{getattr(processor, 'name_or_path', None)}` "
482
+ f"are different, please ensure that they are compatible."
483
+ )
484
+
485
+ # setup stopping criteria
486
+ stopping_ids_list = stopping_ids or []
487
+
488
+ class StopOnTokens(StoppingCriteria):
489
+ def __call__(
490
+ self,
491
+ input_ids: torch.LongTensor,
492
+ scores: torch.FloatTensor,
493
+ **kwargs: Any,
494
+ ) -> bool:
495
+ return any(input_ids[0][-1] == stop_id for stop_id in stopping_ids_list)
496
+
497
+ stopping_criteria = StoppingCriteriaList([StopOnTokens()])
498
+
499
+ if isinstance(query_wrapper_prompt, str):
500
+ query_wrapper_prompt = PromptTemplate(query_wrapper_prompt)
501
+
502
+ messages_images_to_prompt = messages_images_to_prompt or self._processor_messages_to_prompt
503
+
504
+ # Initiate standard LLM
505
+ super().__init__(
506
+ callback_manager=callback_manager or CallbackManager([]),
507
+ )
508
+ logger.info(f"CUDA Memory Post-SuperInit: {torch.cuda.mem_get_info()}")
509
+
510
+ # Initiate remaining fields
511
+ self._model = model
512
+ self._tokenizer = tokenizer
513
+ self._processor = processor
514
+ logger.info(f"CUDA Memory Post-Init: {torch.cuda.mem_get_info()}")
515
+ self._stopping_criteria = stopping_criteria
516
+ self.model_name = model_name
517
+ self.context_window=context_window
518
+ self.max_new_tokens=max_new_tokens
519
+ self.system_prompt=system_prompt
520
+ self.query_wrapper_prompt=query_wrapper_prompt
521
+ self.tokenizer_name=tokenizer_name
522
+ self.processor_name=processor_name
523
+ self.model_name=model_name
524
+ self.device_map=device_map
525
+ self.stopping_ids=stopping_ids or []
526
+ self.tokenizer_outputs_to_remove=tokenizer_outputs_to_remove or []
527
+ self.tokenizer_kwargs=tokenizer_kwargs or {}
528
+ self.processor_kwargs=processor_kwargs or {}
529
+ self.model_kwargs=model_kwargs or {}
530
+ self.generate_kwargs=generate_kwargs or {}
531
+ self.is_chat_model=is_chat_model
532
+ self.messages_images_to_prompt=messages_images_to_prompt
533
+ # self.completion_to_prompt=completion_to_prompt,
534
+ # self.pydantic_program_mode=pydantic_program_mode,
535
+ # self.output_parser=output_parser,
536
+
537
+ @classmethod
538
+ def class_name(cls) -> str:
539
+ return "HuggingFace_MultiModal_LLM"
540
+
541
+ @property
542
+ def metadata(self) -> LLMMetadata:
543
+ """LLM metadata."""
544
+ return LLMMetadata(
545
+ context_window=self.context_window,
546
+ num_output=self.max_new_tokens,
547
+ model_name=self.model_name,
548
+ is_chat_model=self.is_chat_model,
549
+ )
550
+
551
+ def _processor_messages_to_prompt(self, messages: Sequence[ChatMessage], images: Sequence[ImageDocument]) -> str:
552
+ ### TODO(Jonathan Wang): Make this work generically. Currently we're building for `Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5`
553
+ """Converts a list of messages into a prompt for the multimodal LLM.
554
+ NOTE: we assume for simplicity here that these images are related, and not the user bouncing between multiple different topics. Thus, we send them all at once.
555
+
556
+ Args:
557
+ messages (Sequence[ChatMessage]): A list of the messages to convert, where each message is a dict containing the message role and content.
558
+ images (Sequence[ImageDocument]): The number of images the user is passing to the MultiModalLLM.
559
+ Returns:
560
+ str: The prompt.
561
+ """
562
+ # NOTE: For `Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5`, we actually ignore the `images`; no placeholders are inserted.
563
+
564
+ """Use the tokenizer to convert messages to prompt. Fallback to generic."""
565
+ if hasattr(self._tokenizer, "apply_chat_template"):
566
+ messages_dict = [
567
+ {"role": message.role.value, "content": message.content}
568
+ for message in messages
569
+ ]
570
+ return self._tokenizer.apply_chat_template(
571
+ messages_dict, tokenize=False, add_generation_prompt=True
572
+ )
573
+
574
+ return generic_messages_to_prompt(messages)
575
+
576
+ @llm_completion_callback()
577
+ def complete(
578
+ self,
579
+ prompt: str,
580
+ image_documents: ImageNode | List[ImageNode] | ImageDocument | List[ImageDocument], # this also takes ImageDocument which inherits from ImageNode.
581
+ formatted: bool = False,
582
+ **kwargs: Any
583
+ ) -> CompletionResponse:
584
+ """Given a prompt and image node(s), get the Phi-3 Vision prompt"""
585
+ # Handle images input
586
+ # https://huggingface.co/Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5/blob/main/demo.ipynb
587
+ batch_image_list = []
588
+ batch_image_sizes = []
589
+ batch_prompt = []
590
+
591
+ # Fix image_documents input typing
592
+ if (not isinstance(image_documents, list)):
593
+ image_documents = [image_documents]
594
+ image_documents = [cast(ImageDocument, image) for image in image_documents] # we probably won't be using the Document features, so I think this is fine.
595
+
596
+ # Convert input images into PIL images for the model.
597
+ image_list = []
598
+ image_sizes = []
599
+ for image in image_documents:
600
+ # NOTE: ImageDocument inherits from ImageNode. We'll go extract the image.
601
+ image_io = image.resolve_image()
602
+ image_pil = PILImage.open(image_io)
603
+ image_list.append(self._processor([image_pil], image_aspect_ratio='anyres')['pixel_values'].to(self._model.device))
604
+ image_sizes.append(image_pil.size)
605
+
606
+ batch_image_list.append(image_list)
607
+ batch_image_sizes.append(image_sizes)
608
+ batch_prompt.append(prompt) # only one question per image
609
+
610
+ # Get the prompt
611
+ if not formatted and self.query_wrapper_prompt:
612
+ prompt = self.query_wrapper_prompt.format(query_str=prompt)
613
+
614
+ prompt_sequence = []
615
+ if self.system_prompt:
616
+ prompt_sequence.append(ChatMessage(role=MessageRole.SYSTEM, content=self.system_prompt))
617
+ prompt_sequence.append(ChatMessage(role=MessageRole.USER, content=prompt))
618
+
619
+ prompt = self.messages_images_to_prompt(messages=prompt_sequence, images=image_documents)
620
+
621
+ # Get the model input
622
+ batch_inputs = {
623
+ "pixel_values": batch_image_list
624
+ }
625
+ language_inputs = self._tokenizer(
626
+ [prompt],
627
+ return_tensors="pt",
628
+ padding='longest', # probably not needed.
629
+ max_length=self._tokenizer.model_max_length,
630
+ truncation=True
631
+ ).to(self._model.device)
632
+ # TODO: why does the example cookbook have this weird conversion to Cuda instead of .to(device)?
633
+ # language_inputs = {name: tensor.cuda() for name, tensor in language_inputs.items()}
634
+ batch_inputs.update(language_inputs)
635
+
636
+ gc.collect()
637
+ torch.cuda.empty_cache()
638
+
639
+ # remove keys from the tokenizer if needed, to avoid HF errors
640
+ # TODO: this probably is broken and wouldn't work.
641
+ for key in self.tokenizer_outputs_to_remove:
642
+ if key in batch_inputs:
643
+ batch_inputs.pop(key, None)
644
+
645
+ # Get output
646
+ tokens = self._model.generate(
647
+ **batch_inputs,
648
+ image_sizes=batch_image_sizes,
649
+ pad_token_id=self._tokenizer.pad_token_id,
650
+ eos_token_id=self._tokenizer.eos_token_id,
651
+ max_new_tokens=self.max_new_tokens,
652
+ stopping_criteria=self._stopping_criteria,
653
+ # NOTE: Special snowflake processor for Salesforce XGEN Phi3 Mini.
654
+ logits_processor=[EosLogitProcessor(eos_token_id=self._tokenizer.eos_token_id, end_token_id=32007)],
655
+ **self.generate_kwargs
656
+ )
657
+ gc.collect()
658
+ torch.cuda.empty_cache()
659
+
660
+ # completion_tokens = tokens[:, batch_inputs['input_ids'].shape[1]:]
661
+ completion = self._tokenizer.batch_decode(
662
+ tokens,
663
+ skip_special_tokens=True,
664
+ clean_up_tokenization_spaces=False
665
+ )[0]
666
+ gc.collect()
667
+ torch.cuda.empty_cache()
668
+
669
+ output = CompletionResponse(text=completion, raw={'model_output': tokens})
670
+
671
+ # Clean stuff up
672
+ del batch_image_list, batch_image_sizes, batch_inputs, tokens, completion
673
+ gc.collect()
674
+ torch.cuda.empty_cache()
675
+
676
+ # Return the completion
677
+ return output
678
+
679
+ @llm_completion_callback()
680
+ def stream_complete(
681
+ self, prompt: str, formatted: bool = False, **kwargs: Any
682
+ ) -> CompletionResponseGen:
683
+ raise NotImplementedError
684
+
685
+ @llm_chat_callback()
686
+ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
687
+ raise NotImplementedError
688
+
689
+ @llm_chat_callback()
690
+ def stream_chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponseGen:
691
+ raise NotImplementedError
692
+
693
+ @llm_completion_callback()
694
+ async def acomplete(
695
+ self,
696
+ prompt: str,
697
+ images: ImageNode | List[ImageNode], # this also takes ImageDocument which inherits from ImageNode.
698
+ formatted: bool = False,
699
+ **kwargs: Any
700
+ ) -> CompletionResponse:
701
+ raise NotImplementedError
702
+
703
+ @llm_completion_callback()
704
+ async def astream_complete(
705
+ self, prompt: str, formatted: bool = False, **kwargs: Any
706
+ ) -> CompletionResponseGen:
707
+ raise NotImplementedError
708
+
709
+ @llm_chat_callback()
710
+ async def achat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
711
+ raise NotImplementedError
712
+
713
+ @llm_chat_callback()
714
+ async def astream_chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponseGen:
715
+ raise NotImplementedError
716
+
717
+
718
+ # @st.cache_resource()
719
+ # def get_multimodal_llm(**kwargs) -> MultiModalLLM:
720
+ # vision_llm = OpenAIMultiModal(
721
+ # model='gpt-4o-mini',
722
+ # temperature=0,
723
+ # max_new_tokens=512,
724
+ # image_detail='auto'
725
+ # )
726
+ # return (vision_llm)
727
+
728
+ @st.cache_resource
729
+ def get_multimodal_llm(
730
+ model_name: str = DEFAULT_HF_MULTIMODAL_LLM,
731
+ device_map: str = "cuda", # does not support 'auto'
732
+ processor_kwargs: dict[str, Any] | None = None,
733
+ model_kwargs: dict[str, Any] | None = None, # {'torch_dtype': torch.bfloat16}, # {'torch_dtype': torch.float8_e5m2}
734
+ generate_kwargs: dict[str, Any] | None = None, # from the example cookbook
735
+
736
+ hf_quant_level: int | None = 8,
737
+ ) -> HuggingFaceMultiModalLLM:
738
+
739
+ # Get default generate kwargs
740
+ if model_kwargs is None:
741
+ model_kwargs = {}
742
+ if processor_kwargs is None:
743
+ processor_kwargs = {}
744
+ if generate_kwargs is None:
745
+ generate_kwargs = {
746
+ "temperature": sys.float_info.min,
747
+ "top_p": None,
748
+ "num_beams": 1
749
+ # NOTE: we hack in EOSLogitProcessor in the HuggingFaceMultiModalLLM because it allows us to get the tokenizer.eos_token_id
750
+ }
751
+
752
+ # Get Quantization with Quanto
753
+ quanto_config = None # NOTE: by default, no quantization.
754
+ if (hf_quant_level == 4):
755
+ # bnb_config = BitsAndBytesConfig(
756
+ # # load_in_8bit=True,
757
+ # load_in_4bit=True,
758
+ # # bnb_4bit_use_double_quant=True,
759
+ # bnb_4bit_quant_type="nf4",
760
+ # bnb_4bit_compute_dtype='bfloat16', # NOTE: Tesla T4 GPUs are too crappy for bfloat16
761
+ # # bnb_4bit_compute_dtype='float16'
762
+ # )
763
+ quanto_config = QuantoConfig(
764
+ weights="int4" # there's also 'int2' if you're crazy...
765
+ )
766
+ elif (hf_quant_level == 8):
767
+ # bnb_config = BitsAndBytesConfig(
768
+ # load_in_8bit=True
769
+ # )
770
+ quanto_config = QuantoConfig(
771
+ weights="int8"
772
+ )
773
+
774
+ if (quanto_config is not None):
775
+ model_kwargs["quantization_config"] = quanto_config
776
+
777
+ return HuggingFaceMultiModalLLM(
778
+ model_name=model_name,
779
+ device_map=device_map,
780
+ processor_kwargs=processor_kwargs,
781
+ model_kwargs=model_kwargs,
782
+ generate_kwargs=generate_kwargs,
783
+
784
+ max_new_tokens=1024 # from the example cookbook
785
+ )
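+ # A rough usage sketch for the multimodal path (the image path is hypothetical; assumes the default
+ # Salesforce checkpoint can be downloaded):
+ #
+ # from llama_index.core.schema import ImageDocument
+ #
+ # vision_llm = get_multimodal_llm()
+ # image = ImageDocument(image_path="figures/page_3_table.png")
+ # caption = vision_llm.complete(prompt="Describe this figure.", image_documents=[image])
+ # print(caption.text)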
obs_logging.py ADDED
@@ -0,0 +1,380 @@
1
+ #####################################################
2
+ ### DOCUMENT PROCESSOR [OBSERVATION/LOGGING]
3
+ #####################################################
4
+ # Jonathan Wang
5
+
6
+ # ABOUT:
7
+ # This project creates an app to chat with PDFs.
8
+
9
+ # This is the Observation and Logging
10
+ # to see the actions undertaken in the RAG pipeline.
11
+ #####################################################
12
+ ## TODOS:
13
+ # Why does FullRAGEventHandler keep producing duplicate output?
14
+
15
+ #####################################################
16
+ ## IMPORTS:
17
+ from __future__ import annotations
18
+
19
+ import logging
20
+ from typing import TYPE_CHECKING, Any, ClassVar, Sequence
21
+
22
+ import streamlit as st
23
+
24
+ # Callbacks
25
+ from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler
26
+
27
+ # Pretty Printing
28
+ # from llama_index.core.response.notebook_utils import display_source_node
29
+ # End user handler
30
+ from llama_index.core.instrumentation import get_dispatcher
31
+ from llama_index.core.instrumentation.event_handlers import BaseEventHandler
32
+ from llama_index.core.instrumentation.events.agent import (
33
+ AgentChatWithStepEndEvent,
34
+ AgentChatWithStepStartEvent,
35
+ AgentRunStepEndEvent,
36
+ AgentRunStepStartEvent,
37
+ AgentToolCallEvent,
38
+ )
39
+ from llama_index.core.instrumentation.events.chat_engine import (
40
+ StreamChatDeltaReceivedEvent,
41
+ StreamChatErrorEvent,
42
+ )
43
+ from llama_index.core.instrumentation.events.embedding import (
44
+ EmbeddingEndEvent,
45
+ EmbeddingStartEvent,
46
+ )
47
+ from llama_index.core.instrumentation.events.llm import (
48
+ LLMChatEndEvent,
49
+ LLMChatInProgressEvent,
50
+ LLMChatStartEvent,
51
+ LLMCompletionEndEvent,
52
+ LLMCompletionStartEvent,
53
+ LLMPredictEndEvent,
54
+ LLMPredictStartEvent,
55
+ LLMStructuredPredictEndEvent,
56
+ LLMStructuredPredictStartEvent,
57
+ )
58
+ from llama_index.core.instrumentation.events.query import (
59
+ QueryEndEvent,
60
+ QueryStartEvent,
61
+ )
62
+ from llama_index.core.instrumentation.events.rerank import (
63
+ ReRankEndEvent,
64
+ ReRankStartEvent,
65
+ )
66
+ from llama_index.core.instrumentation.events.retrieval import (
67
+ RetrievalEndEvent,
68
+ RetrievalStartEvent,
69
+ )
70
+ from llama_index.core.instrumentation.events.span import (
71
+ SpanDropEvent,
72
+ )
73
+ from llama_index.core.instrumentation.events.synthesis import (
74
+ # GetResponseEndEvent,
75
+ GetResponseStartEvent,
76
+ SynthesizeEndEvent,
77
+ SynthesizeStartEvent,
78
+ )
79
+ from llama_index.core.instrumentation.span import SimpleSpan
80
+ from llama_index.core.instrumentation.span_handlers.base import BaseSpanHandler
81
+ from treelib import Tree
82
+
83
+ if TYPE_CHECKING:
84
+ from llama_index.core.instrumentation.dispatcher import Dispatcher
85
+ from llama_index.core.instrumentation.events import BaseEvent
86
+ from llama_index.core.schema import BaseNode, NodeWithScore
87
+
88
+ #####################################################
89
+ ## Code
90
+ logger = logging.getLogger(__name__)
91
+
92
+ @st.cache_resource
93
+ def get_callback_manager() -> CallbackManager:
94
+ """Create the callback manager for the code."""
95
+ return CallbackManager([LlamaDebugHandler()])
96
+
97
+
98
+ def display_source_node(source_node: NodeWithScore, max_length: int = 100) -> str:
99
+ source_text = source_node.node.get_content().strip()
100
+ source_text = source_text[:max_length] + "..." if len(source_text) > max_length else source_text
101
+ return (
102
+ f"**Node ID:** {source_node.node.node_id}<br>"
103
+ f"**Similarity:** {source_node.score}<br>"
104
+ f"**Text:** {source_text}<br>"
105
+ )
106
+
107
+ class RAGEventHandler(BaseEventHandler):
108
+ """Pruned RAG Event Handler."""
109
+
110
+ # events: List[BaseEvent] = [] # TODO: handle removing historical events if they're too old.
111
+
112
+ @classmethod
113
+ def class_name(cls) -> str:
114
+ """Class name."""
115
+ return "RAGEventHandler"
116
+
117
+ def handle(self, event: BaseEvent, **kwargs: Any) -> None:
118
+ """Logic for handling event."""
119
+ print("-----------------------")
120
+ # all events have these attributes
121
+ print(event.id_)
122
+ print(event.timestamp)
123
+ print(event.span_id)
124
+
125
+ # event specific attributes
126
+ if isinstance(event, LLMChatStartEvent):
127
+ # initial
128
+ print(event.messages)
129
+ print(event.additional_kwargs)
130
+ print(event.model_dict)
131
+ elif isinstance(event, LLMChatInProgressEvent):
132
+ # streaming
133
+ print(event.response.delta)
134
+ elif isinstance(event, LLMChatEndEvent):
135
+ # final response
136
+ print(event.response)
137
+
138
+ # self.events.append(event)
139
+ print("-----------------------")
140
+
141
+ class FullRAGEventHandler(BaseEventHandler):
142
+ """RAG event handler. Built off the example custom event handler.
143
+
144
+ In general, logged events are treated as single events in a point in time,
145
+ that link to a span. The span is a collection of events that are related to
146
+ a single task. The span is identified by a unique span_id.
147
+
148
+ While events are independent, there is some hierarchy.
149
+ For example, in query_engine.query() call with a reranker attached:
150
+ - QueryStartEvent
151
+ - RetrievalStartEvent
152
+ - EmbeddingStartEvent
153
+ - EmbeddingEndEvent
154
+ - RetrievalEndEvent
155
+ - RerankStartEvent
156
+ - RerankEndEvent
157
+ - SynthesizeStartEvent
158
+ - GetResponseStartEvent
159
+ - LLMPredictStartEvent
160
+ - LLMChatStartEvent
161
+ - LLMChatEndEvent
162
+ - LLMPredictEndEvent
163
+ - GetResponseEndEvent
164
+ - SynthesizeEndEvent
165
+ - QueryEndEvent
166
+ """
167
+
168
+ events: ClassVar[list[BaseEvent]] = []
169
+ @classmethod
170
+ def class_name(cls) -> str:
171
+ """Class name."""
172
+ return "RAGEventHandler"
173
+
174
+ def _print_event_nodes(self, event_nodes: Sequence[NodeWithScore | BaseNode]) -> str:
175
+ """Print a list of nodes nicely."""
176
+ output_str = "["
177
+ for node in event_nodes:
178
+ output_str += (str(display_source_node(node, 1000)) + "\n")
179
+ output_str += "* * * * * * * * * * * *"
180
+ output_str += "]"
181
+ return (output_str)
182
+
183
+ def handle(self, event: BaseEvent, **kwargs: Any) -> None:
184
+ """Logic for handling event."""
185
+ logger.info("-----------------------")
186
+ # all events have these attributes
187
+ logger.info(event.id_)
188
+ logger.info(event.timestamp)
189
+ logger.info(event.span_id)
190
+
191
+ # event specific attributes
192
+ logger.info(f"Event type: {event.class_name()}")
193
+ if isinstance(event, AgentRunStepStartEvent):
194
+ # logger.info(event.task_id)
195
+ logger.info(event.step)
196
+ logger.info(event.input)
197
+ if isinstance(event, AgentRunStepEndEvent):
198
+ logger.info(event.step_output)
199
+ if isinstance(event, AgentChatWithStepStartEvent):
200
+ logger.info(event.user_msg)
201
+ if isinstance(event, AgentChatWithStepEndEvent):
202
+ logger.info(event.response)
203
+ if isinstance(event, AgentToolCallEvent):
204
+ logger.info(event.arguments)
205
+ logger.info(event.tool.name)
206
+ logger.info(event.tool.description)
207
+ if isinstance(event, StreamChatDeltaReceivedEvent):
208
+ logger.info(event.delta)
209
+ if isinstance(event, StreamChatErrorEvent):
210
+ logger.info(event.exception)
211
+ if isinstance(event, EmbeddingStartEvent):
212
+ logger.info(event.model_dict)
213
+ if isinstance(event, EmbeddingEndEvent):
214
+ logger.info(event.chunks)
215
+ logger.info(event.embeddings[0][:5]) # avoid printing all embeddings
216
+ if isinstance(event, LLMPredictStartEvent):
217
+ logger.info(event.template)
218
+ logger.info(event.template_args)
219
+ if isinstance(event, LLMPredictEndEvent):
220
+ logger.info(event.output)
221
+ if isinstance(event, LLMStructuredPredictStartEvent):
222
+ logger.info(event.template)
223
+ logger.info(event.template_args)
224
+ logger.info(event.output_cls)
225
+ if isinstance(event, LLMStructuredPredictEndEvent):
226
+ logger.info(event.output)
227
+ if isinstance(event, LLMCompletionStartEvent):
228
+ logger.info(event.model_dict)
229
+ logger.info(event.prompt)
230
+ logger.info(event.additional_kwargs)
231
+ if isinstance(event, LLMCompletionEndEvent):
232
+ logger.info(event.response)
233
+ logger.info(event.prompt)
234
+ if isinstance(event, LLMChatInProgressEvent):
235
+ logger.info(event.messages)
236
+ logger.info(event.response)
237
+ if isinstance(event, LLMChatStartEvent):
238
+ logger.info(event.messages)
239
+ logger.info(event.additional_kwargs)
240
+ logger.info(event.model_dict)
241
+ if isinstance(event, LLMChatEndEvent):
242
+ logger.info(event.messages)
243
+ logger.info(event.response)
244
+ if isinstance(event, RetrievalStartEvent):
245
+ logger.info(event.str_or_query_bundle)
246
+ if isinstance(event, RetrievalEndEvent):
247
+ logger.info(event.str_or_query_bundle)
248
+ # logger.info(event.nodes)
249
+ logger.info(self._print_event_nodes(event.nodes))
250
+ if isinstance(event, ReRankStartEvent):
251
+ logger.info(event.query)
252
+ # logger.info(event.nodes)
253
+ for node in event.nodes:
254
+ logger.info(display_source_node(node))
255
+ logger.info(event.top_n)
256
+ logger.info(event.model_name)
257
+ if isinstance(event, ReRankEndEvent):
258
+ # logger.info(event.nodes)
259
+ logger.info(self._print_event_nodes(event.nodes))
260
+ if isinstance(event, QueryStartEvent):
261
+ logger.info(event.query)
262
+ if isinstance(event, QueryEndEvent):
263
+ logger.info(event.response)
264
+ logger.info(event.query)
265
+ if isinstance(event, SpanDropEvent):
266
+ logger.info(event.err_str)
267
+ if isinstance(event, SynthesizeStartEvent):
268
+ logger.info(event.query)
269
+ if isinstance(event, SynthesizeEndEvent):
270
+ logger.info(event.response)
271
+ logger.info(event.query)
272
+ if isinstance(event, GetResponseStartEvent):
273
+ logger.info(event.query_str)
274
+ self.events.append(event)
275
+ logger.info("-----------------------")
276
+
277
+ def _get_events_by_span(self) -> dict[str, list[BaseEvent]]:
278
+ events_by_span: dict[str, list[BaseEvent]] = {}
279
+ for event in self.events:
280
+ if event.span_id in events_by_span:
281
+ events_by_span[event.span_id].append(event)
282
+ elif (event.span_id is not None):
283
+ events_by_span[event.span_id] = [event]
284
+ return events_by_span
285
+
286
+ def _get_event_span_trees(self) -> list[Tree]:
287
+ events_by_span = self._get_events_by_span()
288
+
289
+ trees = []
290
+ tree = Tree()
291
+
292
+ for span, sorted_events in events_by_span.items():
293
+ # create root node i.e. span node
294
+ tree.create_node(
295
+ tag=f"{span} (SPAN)",
296
+ identifier=span,
297
+ parent=None,
298
+ data=sorted_events[0].timestamp,
299
+ )
300
+ for event in sorted_events:
301
+ tree.create_node(
302
+ tag=f"{event.class_name()}: {event.id_}",
303
+ identifier=event.id_,
304
+ parent=event.span_id,
305
+ data=event.timestamp,
306
+ )
307
+ trees.append(tree)
308
+ tree = Tree()
309
+ return trees
310
+
311
+ def print_event_span_trees(self) -> None:
312
+ """View trace trees."""
313
+ trees = self._get_event_span_trees()
314
+ for tree in trees:
315
+ logger.info(
316
+ tree.show(
317
+ stdout=False, sorting=True, key=lambda node: node.data
318
+ )
319
+ )
320
+ logger.info("")
321
+
322
+
323
+ class RAGSpanHandler(BaseSpanHandler[SimpleSpan]):
324
+ span_dict: dict = {}
325
+
326
+ @classmethod
327
+ def class_name(cls) -> str:
328
+ """Class name."""
329
+ return "ExampleSpanHandler"
330
+
331
+ def new_span(
332
+ self,
333
+ id_: str,
334
+ bound_args: Any,
335
+ instance: Any | None = None,
336
+ parent_span_id: str | None = None,
337
+ **kwargs: Any,
338
+ ) -> SimpleSpan | None:
339
+ """Create a span."""
340
+ # logic for creating a new MyCustomSpan
341
+ if id_ not in self.span_dict:
342
+ self.span_dict[id_] = []
343
+ self.span_dict[id_].append(
344
+ SimpleSpan(id_=id_, parent_id=parent_span_id)
345
+ )
346
+
347
+ def prepare_to_exit_span(
348
+ self,
349
+ id_: str,
350
+ bound_args: Any,
351
+ instance: Any | None = None,
352
+ result: Any | None = None,
353
+ **kwargs: Any,
354
+ ) -> Any:
355
+ """Logic for preparing to exit a span."""
356
+ # if id in self.span_dict:
357
+ # return self.span_dict[id].pop()
358
+
359
+ def prepare_to_drop_span(
360
+ self,
361
+ id_: str,
362
+ bound_args: Any,
363
+ instance: Any | None = None,
364
+ err: BaseException | None = None,
365
+ **kwargs: Any,
366
+ ) -> Any:
367
+ """Logic for preparing to drop a span."""
368
+ # if id in self.span_dict:
369
+ # return self.span_dict[id].pop()
370
+
371
+
372
+ def get_obs() -> Dispatcher:
373
+ """Get observability for the RAG pipeline."""
374
+ dispatcher = get_dispatcher()
375
+ event_handler = RAGEventHandler()
376
+ span_handler = RAGSpanHandler()
377
+
378
+ dispatcher.add_event_handler(event_handler)
379
+ dispatcher.add_span_handler(span_handler)
380
+ return dispatcher
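+ # Minimal sketch of turning instrumentation on (typically called once at app startup):
+ #
+ # dispatcher = get_obs()
+ # # ... run queries as usual; RAGEventHandler prints LLM chat events as they happen,
+ # # and RAGSpanHandler keeps a dict of SimpleSpans keyed by span id.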
packages.txt ADDED
@@ -0,0 +1,4 @@
1
+ libmagic-dev
2
+ poppler-utils
3
+ tesseract-ocr
4
+ pandoc
parsers.py ADDED
@@ -0,0 +1,106 @@
1
+ #####################################################
2
+ ### DOCUMENT PROCESSOR [PARSERS]
3
+ #####################################################
4
+ # Jonathan Wang
5
+
6
+ # ABOUT:
7
+ # This project creates an app to chat with PDFs.
8
+
9
+ # This is the PARSERS.
10
+ # It chunks Raw Text into LlamaIndex nodes
11
+ # E.g., by embedding meaning, by sentence, ...
12
+ #####################################################
13
+ # TODO Board:
14
+ # Add more stuff
15
+
16
+ #####################################################
17
+ ## IMPORTS
18
+ from __future__ import annotations
19
+
20
+ from typing import TYPE_CHECKING, Callable, List, Optional
21
+
22
+ from llama_index.core import Settings
23
+ from llama_index.core.node_parser import (
24
+ SemanticSplitterNodeParser,
25
+ SentenceWindowNodeParser,
26
+ )
27
+
28
+ if TYPE_CHECKING:
29
+ from llama_index.core.base.embeddings.base import BaseEmbedding
30
+ from llama_index.core.callbacks import CallbackManager
31
+ from llama_index.core.node_parser.interface import NodeParser
32
+
33
+ # from wtpsplit import SaT
34
+
35
+ # Lazy Loading
36
+
37
+ #####################################################
38
+ ## CODE
39
+ # def sentence_splitter_from_SaT(model: Optional[SaT]) -> Callable[[str], List[str]]:
40
+ # """Convert a SaT model into a sentence splitter function.
41
+
42
+ # Args:
43
+ # model (SaT): The Segment Anything model.
44
+
45
+ # Returns:
46
+ # Callable[[str], List[str]]: The sentence splitting function using the SaT model.
47
+ # """
48
+ # model = model or ss.model
49
+ # if model is None:
50
+ # raise ValueError("Sentence splitting model is not set.")
51
+
52
+ # def sentence_splitter(text: str) -> List[str]:
53
+ # segments = model.split(text_or_texts=text)
54
+ # if isinstance(segments, list):
55
+ # return segments
56
+ # else:
57
+ # return list(segments) # type: ignore (generator is the other option?)
58
+
59
+ # return (sentence_splitter)
60
+
61
+ # @st.cache_resource # can't cache because embed_model is not hashable.
62
+ def get_parser(
63
+ embed_model: BaseEmbedding,
64
+ # sentence_model: Optional[SaT] = None,
65
+ sentence_splitter: Optional[Callable[[str], List[str]]] = None,
66
+ callback_manager: Optional[CallbackManager] = None
67
+ ) -> NodeParser:
68
+ """Parse RAG document processing (main one)."""
69
+ # if (sentence_model is not None) and (sentence_splitter is not None):
70
+ # sentence_splitter = sentence_splitter_from_SaT(sentence_model)
71
+
72
+ return SemanticSplitterNodeParser.from_defaults(
73
+ embed_model=embed_model,
74
+ breakpoint_percentile_threshold=95,
75
+ buffer_size=3,
76
+ sentence_splitter=sentence_splitter,
77
+ callback_manager=callback_manager or Settings.callback_manager,
78
+ include_metadata=True,
79
+ include_prev_next_rel=True,
80
+ )
81
+
82
+
83
+ # @st.cache_resource
84
+ # def get_sentence_parser(splitter_model: Optional[SaT] = None) -> SentenceWindowNodeParser:
85
+ # """Special sentence-level parser to get the document requested info section."""
86
+ # if (splitter_model is not None):
87
+ # sentence_splitter = sentence_splitter_from_SaT(splitter_model)
88
+
89
+ # sentence_parser = SentenceWindowNodeParser.from_defaults(
90
+ # sentence_splitter=sentence_splitter,
91
+ # window_size=0,
92
+ # window_metadata_key="window",
93
+ # original_text_metadata_key="original_text",
94
+ # )
95
+ # return (sentence_parser)
96
+
97
+ def get_sentence_parser() -> SentenceWindowNodeParser:
98
+ """Parse sentences to get the document requested info section."""
99
+ # if (splitter_model is not None):
100
+ # sentence_splitter = sentence_splitter_from_SaT(splitter_model)
101
+ return SentenceWindowNodeParser.from_defaults(
102
+ # sentence_splitter=sentence_splitter,
103
+ window_size=0,
104
+ window_metadata_key="window",
105
+ original_text_metadata_key="original_text",
106
+ )
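
A minimal sketch of calling the two parsers above, assuming an embedding model has already been configured on Settings (e.g., in models.py); the sample documents are illustrative only.

    from llama_index.core import Document, Settings
    from parsers import get_parser, get_sentence_parser

    # Assumes Settings.embed_model was set elsewhere in the app.
    parser = get_parser(embed_model=Settings.embed_model)
    nodes = parser.get_nodes_from_documents(
        [Document(text="First idea. Second idea. Now a different topic entirely.")]
    )

    sentence_parser = get_sentence_parser()
    sentence_nodes = sentence_parser.get_nodes_from_documents(
        [Document(text="One sentence. Another sentence.")]
    )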
pdf_reader.py ADDED
@@ -0,0 +1,528 @@
1
+ #####################################################
2
+ ### DOCUMENT PROCESSOR [PDF READER]
3
+ #####################################################
4
+ # Jonathan Wang
5
+
6
+ # ABOUT:
7
+ # This project creates an app to chat with PDFs.
8
+
9
+ # This is the PDF READER.
10
+ # It converts a PDF into LlamaIndex nodes
11
+ # using UnstructuredIO.
12
+ #####################################################
13
+ # TODO Board:
14
+ # I don't think the current code is elegant... :(
15
+
16
+ # TODO: Replace chunk_by_header with a custom solution replicating bySimilarity
17
+ # https://docs.unstructured.io/api-reference/api-services/chunking#by-similarity-chunking-strategy
18
+ # Some hybrid thing...
19
+
20
+
21
+ # Come up with a way to handle summarizing images and tables using MultiModalLLM after the processing into nodes.
22
+ # TODO: Put this into PDFReaderUtilities? Along with the other functions for stuff like email?
23
+
24
+ # Investigate PDFPlumber as a backup/alternative for Unstructured.
25
+ # `https://github.com/jsvine/pdfplumber`
26
+ # nevermind, this is essentially pdfminer.six but nicer
27
+
28
+ # Chunk hierarchy from https://www.reddit.com/r/LocalLLaMA/comments/1dpb9ow/how_we_chunk_turning_pdfs_into_hierarchical/
29
+ # Investigate document parsing algorithms from https://github.com/BobLd/DocumentLayoutAnalysis?tab=readme-ov-file
30
+ # Investigate document parsing algorithms from https://github.com/Filimoa/open-parse?tab=readme-ov-file
31
+
32
+ # Competition:
33
+ # https://github.com/infiniflow/ragflow
34
+ # https://github.com/deepdoctection/deepdoctection
35
+
36
+ #####################################################
37
+ ## IMPORTS
38
+ import os
39
+ import re
40
+ import regex
41
+ from copy import deepcopy
42
+
43
+ from abc import ABC, abstractmethod
44
+ from typing import Any, List, Tuple, IO, Optional, Type, Generic, TypeVar
45
+ from llama_index.core.bridge.pydantic import Field
46
+
47
+ import numpy as np
48
+
49
+ from io import BytesIO
50
+ from base64 import b64encode, b64decode
51
+ from PIL import Image as PILImage
52
+
53
+ # from pdf_reader_utils import clean_pdf_chunk, dedupe_title_chunks, combine_listitem_chunks
54
+
55
+ # Unstructured Document Parsing
56
+ from unstructured.partition.pdf import partition_pdf
57
+ # from unstructured.cleaners.core import clean_extra_whitespace, group_broken_paragraphs #, clean_ordered_bullets, clean_bullets, clean_dashes
58
+ # from unstructured.chunking.title import chunk_by_title
59
+ # Unstructured Element Types
60
+ from unstructured.documents import elements, email_elements
61
+ from unstructured.partition.utils.constants import PartitionStrategy
62
+
63
+ # Llamaindex Nodes
64
+ from llama_index.core.settings import Settings
65
+ from llama_index.core.schema import Document, BaseNode, TextNode, ImageNode, NodeRelationship, RelatedNodeInfo
66
+ from llama_index.core.readers.base import BaseReader
67
+ from llama_index.core.base.embeddings.base import BaseEmbedding
68
+ from llama_index.core.node_parser import NodeParser
69
+
70
+ # Parallelism for cleaning chunks
71
+ from joblib import Parallel, delayed
72
+
73
+ ## Lazy Imports
74
+ # import nltk
75
+ #####################################################
76
+
77
+ # Additional padding around the PDF extracted images
78
+ PDF_IMAGE_HORIZONTAL_PADDING = 20
79
+ PDF_IMAGE_VERTICAL_PADDING = 20
80
+ os.environ['EXTRACT_IMAGE_BLOCK_CROP_HORIZONTAL_PAD'] = str(PDF_IMAGE_HORIZONTAL_PADDING)
81
+ os.environ['EXTRACT_IMAGE_BLOCK_CROP_VERTICAL_PAD'] = str(PDF_IMAGE_VERTICAL_PADDING)
82
+
83
+ # class TextReader(BaseReader):
84
+ # def __init__(self, text: str) -> None:
85
+ # """Init params."""
86
+ # self.text = text
87
+
88
+
89
+ # class ImageReader(BaseReader):
90
+ # def __init__(self, image: Any) -> None:
91
+ # """Init params."""
92
+ # self.image = image
93
+
94
+ GenericNode = TypeVar("GenericNode", bound=BaseNode) # https://mypy.readthedocs.io/en/stable/generics.html
95
+
96
+ class UnstructuredPDFReader():
97
+ # Yes, we could inherit from LlamaIndex BaseReader even though I don't think it's a good idea.
98
+ # Have you seen the Llamaindex Base Reader? It's silly. """OOP"""
99
+ # https://docs.llamaindex.ai/en/stable/api_reference/readers/
100
+
101
+ # here I'm basically cargo culting off the (not-very-good) pre-built Llamaindex one.
102
+ # https://github.com/run-llama/llama_index/blob/main/llama-index-integrations/readers/llama-index-readers-file/llama_index/readers/file/unstructured/base.py
103
+
104
+ # yes I do want to bind these to the class.
105
+ # you better not be changing the embedding model or node parser on me across different PDFReaders. that's absurd.
106
+ # embed_model: BaseEmbedding
107
+ # _node_parser: NodeParser# = Field(
108
+ # description="Node parser to run on each Unstructured Title Chunk",
109
+ # default=Settings.node_parser,
110
+ # )
111
+ _max_characters: int# = Field(
112
+ # description="The maximum number of characters in a node",
113
+ # default=8192,
114
+ # )
115
+ _new_after_n_chars: int #= Field(
116
+ # description="The number of characters after which a new node is created",
117
+ # default=1024,
118
+ # )
119
+ _overlap_n_chars: int #= Field(
120
+ # description="The number of characters to overlap between nodes",
121
+ # default=128,
122
+ # )
123
+ _overlap: int #= Field(
124
+ # description="The number of characters to overlap between nodes",
125
+ # default=128,
126
+ # )
127
+ _overlap_all: bool #= Field(
128
+ # description="Whether to overlap all nodes",
129
+ # default=False,
130
+ # )
131
+ _multipage_sections: bool #= Field(
132
+ # description="Whether to include multipage sections",
133
+ # default=False,
134
+ # )
135
+
136
+ ## TODO: Fix this big ball of primitives and turn it into a class.
137
+ def __init__(
138
+ self,
139
+ # node_parser: Optional[NodeParser], # Suggest using a SemanticNodeParser.
140
+ max_characters: int = 2048,
141
+ new_after_n_chars: int = 512,
142
+ overlap_n_chars: int = 128,
143
+ overlap: int = 128,
144
+ overlap_all: bool = False,
145
+ multipage_sections: bool = True,
146
+ **kwargs: Any
147
+ ) -> None:
148
+ # node_parser = node_parser or Settings.node_parser
149
+ """Init params."""
150
+ super().__init__(**kwargs)
151
+
152
+ self._max_characters = max_characters
153
+ self._new_after_n_chars = new_after_n_chars
154
+ self._overlap_n_chars = overlap_n_chars
155
+ self._overlap = overlap
156
+ self._overlap_all = overlap_all
157
+ self._multipage_sections = multipage_sections
158
+ # self._node_parser = node_parser or Settings.node_parser # set node parser to run on each Unstructured Title Chunk
159
+
160
+ # Prerequisites for Unstructured.io to work
161
+ # import nltk
162
+ # nltk.data.path = ['./nltk_data']
163
+ # try:
164
+ # if not nltk.data.find("tokenizers/punkt"):
165
+ # # nltk.download("punkt")
166
+ # print("Can't find punkt.")
167
+ # except Exception as e:
168
+ # # nltk.download("punkt")
169
+ # print(e)
170
+ # try:
171
+ # if not nltk.data.find("taggers/averaged_perceptron_tagger"):
172
+ # # nltk.download("averaged_perceptron_tagger")
173
+ # print("Can't find averaged_perceptron_tagger.")
174
+ # except Exception as e:
175
+ # # nltk.download("averaged_perceptron_tagger")
176
+ # print(e)
177
+
178
+
179
+ # """DATA LOADING FUNCTIONS"""
180
+ def _node_rel_prev_next(self, prev_node: GenericNode, next_node: GenericNode) -> Tuple[GenericNode, GenericNode]:
181
+ """Update pre-next node relationships between two nodes."""
182
+ prev_node.relationships[NodeRelationship.NEXT] = RelatedNodeInfo(
183
+ node_id=next_node.node_id,
184
+ metadata={"filename": next_node.metadata['filename']}
185
+ )
186
+ next_node.relationships[NodeRelationship.PREVIOUS] = RelatedNodeInfo(
187
+ node_id=prev_node.node_id,
188
+ metadata={"filename": prev_node.metadata['filename']}
189
+ )
190
+ return (prev_node, next_node)
191
+
192
+ def _node_rel_parent_child(self, parent_node: GenericNode, child_node: GenericNode) -> Tuple[GenericNode, GenericNode]:
193
+ """Update parent-child node relationships between two nodes."""
194
+ parent_node.relationships[NodeRelationship.CHILD] = RelatedNodeInfo(
195
+ node_id=child_node.node_id,
196
+ metadata={"filename": child_node.metadata['filename']}
197
+ )
198
+ child_node.relationships[NodeRelationship.PARENT] = RelatedNodeInfo(
199
+ node_id=parent_node.node_id,
200
+ metadata={"filename": parent_node.metadata['filename']}
201
+ )
202
+ return (parent_node, child_node)
203
+
204
+ def _handle_metadata(
205
+ self,
206
+ pdf_chunk: elements.Element,
207
+ node: GenericNode,
208
+ kept_metadata: List[str] = [
209
+ 'filename', 'file_directory', 'coordinates',
210
+ 'page_number', 'page_name', 'section',
211
+ 'sent_from', 'sent_to', 'subject',
212
+ 'parent_id', 'category_depth',
213
+ 'text_as_html', 'languages',
214
+ 'emphasized_text_contents', 'link_texts', 'link_urls',
215
+ 'is_continuation', 'detection_class_prob',
216
+ ]) -> GenericNode:
217
+ """Add common unstructured element metadata to LlamaIndex node."""
218
+ pdf_chunk_metadata = pdf_chunk.metadata.to_dict() if pdf_chunk.metadata else {}
219
+ current_kept_metadata = deepcopy(kept_metadata)
220
+
221
+ # Handle some interesting keys
222
+ node.metadata['type'] = pdf_chunk.category
223
+ if (('filename' in current_kept_metadata) and ('filename' in pdf_chunk_metadata) and ('file_directory' in pdf_chunk_metadata)):
224
+ filename = os.path.join(str(pdf_chunk_metadata['file_directory']), str(pdf_chunk_metadata['filename']))
225
+ node.metadata['filename'] = filename
226
+ current_kept_metadata.remove('file_directory') if ('file_directory' in current_kept_metadata) else None
227
+ if (('text_as_html' in current_kept_metadata) and ('text_as_html' in pdf_chunk_metadata)):
228
+ node.metadata['orignal_table_text'] = getattr(node, 'text', '')
229
+ node.text = pdf_chunk_metadata['text_as_html']
230
+ current_kept_metadata.remove('text_as_html')
231
+ if (('coordinates' in current_kept_metadata) and (pdf_chunk_metadata.get('coordinates') is not None)):
232
+ node.metadata['coordinates'] = pdf_chunk_metadata['coordinates']
233
+ current_kept_metadata.remove('coordinates')
234
+ if (('page_number' in current_kept_metadata) and ('page_number' in pdf_chunk_metadata)):
235
+ node.metadata['page_number'] = [pdf_chunk_metadata['page_number']] # save as list to allow for multiple pages
236
+ current_kept_metadata.remove('page_number')
237
+ if (('page_name' in current_kept_metadata) and ('page_name' in pdf_chunk_metadata)):
238
+ node.metadata['page_name'] = [pdf_chunk_metadata['page_name']] # save as list to allow for multiple sheets
239
+ current_kept_metadata.remove('page_name')
240
+
241
+ # Handle the remaining keys
242
+ for key in set(current_kept_metadata).intersection(set(pdf_chunk_metadata.keys())):
243
+ node.metadata[key] = pdf_chunk_metadata[key]
244
+
245
+ return node
246
+
247
+ def _handle_text_chunk(self, pdf_text_chunk: elements.Element) -> TextNode:
248
+ """Given a text chunk from Unstructured, convert it to a TextNode for LlamaIndex.
249
+
250
+ Args:
251
+ pdf_text_chunk (elements.Element): Input text chunk from Unstructured.
252
+
253
+ Returns:
254
+ TextNode: LlamaIndex TextNode which saves the text as HTML for structure.
255
+ """
256
+ new_node = TextNode(
257
+ text=pdf_text_chunk.text,
258
+ id_=pdf_text_chunk.id,
259
+ excluded_llm_metadata_keys=['type', 'parent_id', 'depth', 'filename', 'coordinates', 'link_texts', 'link_urls', 'link_start_indexes', 'orig_nodes', 'orignal_table_text', 'languages', 'detection_class_prob', 'keyword_metadata'],
260
+ excluded_embed_metadata_keys=['type', 'parent_id', 'depth', 'filename', 'coordinates', 'page number', 'original_text', 'window', 'link_texts', 'link_urls', 'link_start_indexes', 'orig_nodes', 'orignal_table_text', 'languages', 'detection_class_prob']
261
+ )
262
+ new_node = self._handle_metadata(pdf_text_chunk, new_node)
263
+ return (new_node)
264
+
265
+
266
+ def _handle_table_chunk(self, pdf_table_chunk: elements.Table | elements.TableChunk) -> TextNode:
267
+ """Given a table chunk from Unstructured, convert it to a TextNode for LlamaIndex.
268
+
269
+ Args:
270
+ pdf_table_chunk (elements.Table | elements.TableChunk): Input table chunk from Unstructured
271
+
272
+ Returns:
273
+ TextNode: LlamaIndex TextNode which saves the table as HTML for structure.
274
+
275
+ NOTE: You will need to get the summary of the table for better performance.
276
+ """
277
+ new_node = TextNode(
278
+ text=pdf_table_chunk.metadata.text_as_html if pdf_table_chunk.metadata.text_as_html else pdf_table_chunk.text,
279
+ id_=pdf_table_chunk.id,
280
+ excluded_llm_metadata_keys=['type', 'parent_id', 'depth', 'filename', 'coordinates', 'link_texts', 'link_urls', 'link_start_indexes', 'orig_nodes', 'orignal_table_text', 'languages', 'detection_class_prob', 'keyword_metadata'],
281
+ excluded_embed_metadata_keys=['type', 'parent_id', 'depth', 'filename', 'coordinates', 'page number', 'original_text', 'window', 'link_texts', 'link_urls', 'link_start_indexes', 'orig_nodes', 'orignal_table_text', 'languages', 'detection_class_prob']
282
+ )
283
+ new_node = self._handle_metadata(pdf_table_chunk, new_node)
284
+ return (new_node)
285
+
286
+
287
+ def _handle_image_chunk(self, pdf_image_chunk: elements.Element) -> ImageNode:
288
+ """Given an image chunk from UnstructuredIO, read it in and convert it into a Llamaindex ImageNode.
289
+
290
+ Args:
291
+ pdf_image_chunk (elements.Element): The input image element from UnstructuredIO. We'll allow all types, just in case you want to process some weird chunks.
292
+
293
+ Returns:
294
+ ImageNode: The image saved as a Llamaindex ImageNode.
295
+ """
296
+ pdf_image_chunk_data_available = pdf_image_chunk.metadata.to_dict()
297
+
298
+ # Check for either saved image_path or image_base64/image_mime_type
299
+ if (('image_path' not in pdf_image_chunk_data_available) and ('image_base64' not in pdf_image_chunk_data_available)):
300
+ raise Exception('Image chunk does not have either image_path or image_base64/image_mime_type. Are you sure this is an image?')
301
+
302
+ # Make the image node.
303
+ new_node = ImageNode(
304
+ text=pdf_image_chunk.text,
305
+ id_=pdf_image_chunk.id,
306
+ excluded_llm_metadata_keys=['type', 'parent_id', 'depth', 'filename', 'coordinates', 'link_texts', 'link_urls', 'link_start_indexes', 'orig_nodes', 'languages', 'detection_class_prob', 'keyword_metadata'],
307
+ excluded_embed_metadata_keys=['type', 'parent_id', 'depth', 'filename', 'coordinates', 'page number', 'original_text', 'window', 'link_texts', 'link_urls', 'link_start_indexes', 'orig_nodes', 'languages', 'detection_class_prob']
308
+ )
309
+ new_node = self._handle_metadata(pdf_image_chunk, new_node)
310
+
311
+ # Add image data to image node
312
+ image = None
313
+ if ('image_path' in pdf_image_chunk_data_available):
314
+ # Save image path to image node
315
+ new_node.image_path = pdf_image_chunk_data_available['image_path']
316
+
317
+ # Load image from path, convert to base64
318
+ image_pil = PILImage.open(pdf_image_chunk_data_available['image_path'])
319
+ image_buffer = BytesIO()
320
+ image_pil.save(image_buffer, format='JPEG')
321
+ image = b64encode(image_buffer.getvalue()).decode('utf-8')
322
+
323
+ new_node.image = image
324
+ new_node.image_mimetype = 'image/jpeg'
325
+ del image_buffer, image_pil
326
+ elif ('image_base64' in pdf_image_chunk_data_available):
327
+ # Save image base64 to image node
328
+ new_node.image = pdf_image_chunk_data_available['image_base64']
329
+ new_node.image_mimetype = pdf_image_chunk_data_available['image_mime_type']
330
+
331
+ return (new_node)
332
+
333
+
334
+ def _handle_composite_chunk(self, pdf_composite_chunk: elements.CompositeElement) -> BaseNode:
335
+ """Given a composite chunk from Unstructured, convert it into a node and handle it dependencies as well."""
336
+ # Start by getting a list of all the nodes which were combined into the composite chunk.
337
+ # child_chunks = pdf_composite_chunk.metadata.to_dict()['orig_elements']
338
+ child_chunks = pdf_composite_chunk.metadata.orig_elements or []
339
+ child_nodes = []
340
+ for chunk in child_chunks:
341
+ child_nodes.append(self._handle_chunk(chunk)) # process all the child chunks.
342
+
343
+ # Then build the Composite Chunk into a Node.
344
+ composite_node = self._handle_text_chunk(pdf_text_chunk=pdf_composite_chunk)
345
+ composite_node = self._handle_metadata(pdf_composite_chunk, composite_node)
346
+
347
+ # Set relationships between chunks.
348
+ for index in range(1, len(child_nodes)):
349
+ child_nodes[index-1], child_nodes[index] = self._node_rel_prev_next(child_nodes[index-1], child_nodes[index])
350
+ for index, node in enumerate(child_nodes):
351
+ composite_node, child_nodes[index] = self._node_rel_parent_child(composite_node, child_nodes[index])
352
+
353
+ composite_node.metadata['orig_nodes'] = child_nodes
354
+ composite_node.excluded_llm_metadata_keys = ['filename', 'coordinates', 'chunk_number', 'window', 'orig_nodes', 'languages', 'detection_class_prob', 'keyword_metadata']
355
+ composite_node.excluded_embed_metadata_keys = ['filename', 'coordinates', 'chunk_number', 'page number', 'original_text', 'window', 'summary', 'orig_nodes', 'languages', 'detection_class_prob']
356
+ return(composite_node)
357
+
358
+
359
+ def _handle_chunk(self, chunk: elements.Element) -> BaseNode:
360
+ """Convert Unstructured element chunks to Llamaindex Node. Determine which chunk handling to use based on the element type."""
361
+ # Composite (multiple nodes combined together by chunking)
362
+ if (isinstance(chunk, elements.CompositeElement)):
363
+ return (self._handle_composite_chunk(pdf_composite_chunk=chunk))
364
+ # Tables
365
+ elif ((chunk.category == 'Table') and isinstance(chunk, (elements.Table, elements.TableChunk))):
366
+ return(self._handle_table_chunk(pdf_table_chunk=chunk))
367
+ # Images
368
+ elif (any(True for chunk_info in ['image', 'image_base64', 'image_path'] if chunk_info in chunk.metadata.to_dict())):
369
+ return(self._handle_image_chunk(pdf_image_chunk=chunk))
370
+ # Text
371
+ else:
372
+ return(self._handle_text_chunk(pdf_text_chunk=chunk))
373
+
374
+
375
+ def pdf_to_chunks(
376
+ self,
377
+ file_path: Optional[str],
378
+ file: Optional[IO[bytes]],
379
+ ) -> List[elements.Element]:
380
+ """
381
+ Given the file path to a PDF, read it in with UnstructuredIO and return its elements.
382
+ """
383
+ print("NEWPDF: Partitioning into Chunks...")
384
+ # 1. attempt using AUTO to have it decide.
385
+ # NOTE: this takes care of pdfminer, and also chooses between using detectron2 vs tesseract only.
386
+ # However, it sometimes gets confused by PDFs where text elements are added on later, e.g., CIDs for linking, or REDACTED
387
+ pdf_chunks = partition_pdf(
388
+ filename=file_path,
389
+ file=file,
390
+ unique_element_ids=True, # UUIDs that are unique for each element
391
+ strategy=PartitionStrategy.HI_RES, # auto: it decides, hi_res: detectron2, but issues with multi-column, ocr_only: pytesseract, fast: pdfminer
392
+ hi_res_model_name='yolox',
393
+ include_page_breaks=False,
394
+ metadata_filename=file_path,
395
+ infer_table_structure=True,
396
+ extract_images_in_pdf=True,
397
+ extract_image_block_types=['Image', 'Table', 'Formula'], # element types to save as images
398
+ extract_image_block_to_payload=False, # needs to be false; we'll convert into base64 later.
399
+ extract_forms=False, # not currently available
400
+ extract_image_block_output_dir=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data/pdfimgs/')
401
+ )
402
+
403
+ # # 2. Check if it got good output.
404
+ # pdf_read_in_okay = self.check_pdf_read_in(pdf_file_path=pdf_file_path, pdf_file=pdf_file, pdf_chunks=pdf_chunks)
405
+ # if (pdf_read_in_okay):
406
+ # return pdf_chunks
407
+
408
+ # # 3. Okay, PDF didn't read in well, so we'll use the back-up strategy
409
+ # # According to Unstructured's Github: https://github.com/Unstructured-IO/unstructured/blob/main/unstructured/partition/pdf.py
410
+ # # that is "OCR_ONLY" as opposed to "HI_RES".
411
+ # pdf_chunks = partition_pdf(
412
+ # filename=pdf_file_path,
413
+ # file=pdf_file,
414
+ # strategy="ocr_only" # auto: it decides, hi_res: detectron2, but issues with multi-column, ocr_only: pytesseract, fast: pdfminer
415
+ # )
416
+ return pdf_chunks
417
+
418
+
419
+ def chunks_to_nodes(self, pdf_chunks: List[elements.Element]) -> List[BaseNode]:
420
+ """
421
+ Given a PDF's element chunks from Unstructured,
422
+ convert them into nodes using the node_parser.
423
+ E.g., to have all sentences with similar meaning as a node, use the SemanticNodeParser
424
+ """
425
+ # 0. Setup.
426
+ unstructured_chunk_nodes = []
427
+
428
+ # Hash of node ID and index
429
+ node_id_to_index = {}
430
+
431
+ # 1. Convert each page's text to Nodes.
432
+ for index, chunk in enumerate(pdf_chunks):
433
+ # Create new node based on node type
434
+ new_node = self._handle_chunk(chunk)
435
+
436
+ # Update hash of node ID and index
437
+ node_id_to_index[new_node.id_] = index
438
+
439
+ # Add relationship to prior node
440
+ if (len(unstructured_chunk_nodes) > 0):
441
+ unstructured_chunk_nodes[-1], new_node = self._node_rel_prev_next(prev_node=unstructured_chunk_nodes[-1], next_node=new_node)
442
+
443
+ # Add parent-child relationships for Title Chunks
444
+ if (chunk.metadata.parent_id is not None):
445
+ # Find the index of the parent node based on parent_id
446
+ parent_index = node_id_to_index[chunk.metadata.parent_id]
447
+ if (parent_index is not None):
448
+ unstructured_chunk_nodes[parent_index], new_node = self._node_rel_parent_child(parent_node=unstructured_chunk_nodes[parent_index], child_node=new_node)
449
+
450
+ # Append to list
451
+ unstructured_chunk_nodes.append(new_node)
452
+
453
+ del node_id_to_index
454
+
455
+ ## TODO: Move this chunk into a separate ReaderPostProcessor in PDFReaderUtils. Bundle the summarization for tables and images into it.
456
+ # 2. Node Parse each page to split when new information is different
457
+ # NOTE: This was built for the Semantic Parser, but I guess we'll technically allow any parser here.
458
+ # unstructured_parsed_nodes = self._node_parser.get_nodes_from_documents(unstructured_chunk_nodes)
459
+
460
+ # 3. Node Attributes
461
+ # for index, node in enumerate(unstructured_parsed_nodes):
462
+ # # Keywords and Summary
463
+ # # node_keywords = ', '.join(pdfrutils.get_keywords(node.text, top_k=5))
464
+ # # node_summary = get_t5_summary(node.text, summary_length=64) # get_t5_summary
465
+ # node.metadata['keywords'] = node_keywords
466
+ # # node.metadata['summary'] = node_summary + (("\n" + node.metadata['summary']) if node.metadata['summary'] is not None else "")
467
+
468
+ # # Get additional information about the node.
469
+ # # Email: check for address.
470
+ # info_types = []
471
+ # if (pdfrutils.has_date(node.text)):
472
+ # info_types.append("date")
473
+ # if (pdfrutils.has_email(node.text)):
474
+ # info_types.append("contact email")
475
+ # if (pdfrutils.has_mail_addr(node.text)):
476
+ # info_types.append("mailing postal address")
477
+ # if (pdfrutils.has_phone(node.text)):
478
+ # info_types.append("contact phone")
479
+
480
+ # node.metadata['information types'] = ", ".join(info_types)
481
+ # node.excluded_llm_metadata_keys = ['filename', 'coordinates', 'chunk_number', 'window', 'orig_nodes']
482
+ # node.excluded_embed_metadata_keys = ['filename', 'coordinates', 'chunk_number', 'page number', 'original_text', 'window', 'keywords', 'summary', 'orig_nodes']
483
+
484
+ # if (index > 0):
485
+ # unstructured_parsed_nodes[index-1], node = self._node_rel_prev_next(unstructured_parsed_nodes[index-1], node)
486
+ return(unstructured_chunk_nodes)
487
+
488
+ # """Main user-interaction function"""
489
+ def load_data(
490
+ self,
491
+ file_path: Optional[str] = None,
492
+ file: Optional[IO[bytes]] = None
493
+ ) -> List: #[GenericNode]:
494
+ """Given a path to a PDF file, load it with Unstructured and convert it into a list of Llamaindex Base Nodes.
495
+ Input:
496
+ - file_path (Optional[str]): the path to the PDF file; alternatively, pass an open binary file via the file argument.
497
+ Output:
498
+ - List[GenericNode]: a list of LlamaIndex nodes, one per parsed Unstructured element chunk.
499
+ """
500
+ # 1. PDF to Chunks
501
+ print("NEWPDF: Reading Input File...")
502
+ pdf_chunks = self.pdf_to_chunks(file_path=file_path, file=file)
503
+ # return (pdf_chunks)
504
+
505
+ # Chunk processing
506
+ # pdf_chunks = clean_pdf_chunk, dedupe_title_chunks, combine_listitem_chunks, remove_header_footer_pagenum
507
+
508
+ # 2. Chunks to titles
509
+ # TODO: I hate this, make our own chunker.
510
+ # pdf_titlechunks = chunk_by_title(
511
+ # pdf_chunks,
512
+ # max_characters=self._max_characters,
513
+ # new_after_n_chars=self._new_after_n_chars,
514
+ # overlap=self._overlap,
515
+ # overlap_all=self._overlap_all,
516
+ # multipage_sections=self._multipage_sections,
517
+ # include_orig_elements=True,
518
+ # combine_text_under_n_chars=self._new_after_n_chars
519
+ # )
520
+ # 3. Cleaning
521
+ # pdf_titlechunks = Parallel(n_jobs=max(int(os.cpu_count())-1, 1))( # type: ignore
522
+ # delayed(self.clean_pdf_chunk)(chunk) for chunk in pdf_chunks # pdf_titlechunks
523
+ # )
524
+ # pdf_titlechunks = list(pdf_titlechunks)
525
+ # 4. Headlines to llamaindex nodes
526
+ print("NEWPDF: Converting chunks to nodes...")
527
+ parsed_chunks = self.chunks_to_nodes(pdf_chunks)
528
+ return (parsed_chunks)
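
A minimal sketch of driving the reader above; the PDF path is hypothetical, and the post-processing helpers referenced in the commented-out pipeline live in pdf_reader_utils.py.

    from pdf_reader import UnstructuredPDFReader

    reader = UnstructuredPDFReader(max_characters=2048, new_after_n_chars=512)
    nodes = reader.load_data(file_path="data/example.pdf")  # hypothetical path

    for node in nodes[:3]:
        print(node.metadata.get("type"), node.metadata.get("page_number"))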
pdf_reader_utils.py ADDED
@@ -0,0 +1,592 @@
1
+ #####################################################
2
+ ### DOCUMENT PROCESSOR [PDF READER UTILITIES]
3
+ #####################################################
4
+ # Jonathan Wang
5
+
6
+ # ABOUT:
7
+ # This project creates an app to chat with PDFs.
8
+
9
+ # This is the PDF READER UTILITIES.
10
+ # It defines helper functions for the PDF reader,
11
+ # such as getting Keywords or finding Contact Info.
12
+ #####################################################
13
+ ### TODO Board:
14
+ # Better Summarizer than T5, which has been stripped out?
15
+ # Better keywords than the RAKE+YAKE fusion we're currently using?
16
+ # Consider using GPE/GSP tagging with spacy to confirm mailing addresses?
17
+
18
+ # Handle FigureCaption somehow.
19
+ # Skip Header if it has a Page X or other page number construction.
20
+
21
+ # Detect images that are substantially overlapping according to coordinates.
22
+ # https://stackoverflow.com/questions/49897531/detect-overlapping-images-in-pil
23
+ # Keep them in the following order: no confidence score, larger image, higher confidence score
24
+
25
+ # Detect nodes whose text is substantially repeated at either the top or bottom of the page.
26
+ # Utilize the coordinates to ignore the text on the top and bottom two lines.
27
+
28
+ # Fix OCR issues with spell checking?
29
+
30
+ # Remove images that are too small in size, and overlapping with text boxes.
31
+
32
+ # Convert the List[BaseNode] -> List[BaseNode] functions into TransformComponents
33
+
34
+ #####################################################
35
+ ### Imports
36
+ from __future__ import annotations
37
+
38
+ import difflib
39
+ import re
40
+ from collections import defaultdict
41
+ from copy import deepcopy
42
+ from typing import (
43
+ TYPE_CHECKING,
44
+ List,
45
+ Optional,
46
+ Tuple,
47
+ TypeVar,
48
+ )
49
+
50
+ import rapidfuzz
51
+ import regex
52
+ from llama_index.core.schema import (
53
+ BaseNode,
54
+ NodeRelationship,
55
+ RelatedNodeInfo,
56
+ )
57
+
58
+ if TYPE_CHECKING:
59
+ from unstructured.documents import elements
60
+
61
+ #####################################################
62
+ ### CODE
63
+
64
+ GenericNode = TypeVar("GenericNode", bound=BaseNode)
65
+
66
+ def clean_pdf_chunk(pdf_chunk: elements.Element) -> elements.Element:
67
+ """Given a single element of text from a pdf read by Unstructured, clean its text."""
68
+ ### NOTE: Don't think it's worth making this a separate TransformComponent.
69
+ # We'd still need to clean bad characters from the reader.
70
+ chunk_text = pdf_chunk.text
71
+ if (len(chunk_text) > 0):
72
+ # Clean any control characters which break the language detection for other parts of the reader.
73
+ re_bad_chars = regex.compile(r"[\p{Cc}\p{Cs}]+")
74
+ chunk_text = re_bad_chars.sub("", chunk_text)
75
+
76
+ # Remove PDF citations text
77
+ chunk_text = re.sub("\\(cid:\\d+\\)", "", chunk_text) # matches (cid:###)
78
+ # Clean whitespace and broken paragraphs
79
+ # chunk_text = clean_extra_whitespace(chunk_text)
80
+ # chunk_text = group_broken_paragraphs(chunk_text)
81
+ # Save cleaned text.
82
+ pdf_chunk.text = chunk_text
83
+
84
+ return pdf_chunk
85
+
86
+
87
+ def clean_abbreviations(pdf_chunks: list[GenericNode]) -> list[GenericNode]:
88
+ """Remove any common abbreviations in the text which can confuse the sentence model.
89
+
90
+ Args:
91
+ pdf_chunks (List[GenericNode]): List of llama-index nodes.
92
+
93
+ Returns:
94
+ List[GenericNode]: The nodes with cleaned text, abbreviations replaced.
95
+ """
96
+ for pdf_chunk in pdf_chunks:
97
+ text = getattr(pdf_chunk, "text", "")
98
+ if (text == ""):
99
+ continue
100
+ # No. -> Number
101
+ text = re.sub(r"\bNo\b\.\s", "Number", text, flags=re.IGNORECASE)
102
+ # Fig. -> Figure
103
+ text = re.sub(r"\bFig\b\.", "Figure", text, flags=re.IGNORECASE)
104
+ # Eq. -> Equation
105
+ text = re.sub(r"\bEq\b\.", "Equation", text, flags=re.IGNORECASE)
106
+ # Mr. -> Mr
107
+ text = re.sub(r"\bMr\b\.", "Mr", text, flags=re.IGNORECASE)
108
+ # Mrs. -> Mrs
109
+ text = re.sub(r"\bMrs\b\.", "Mrs", text, flags=re.IGNORECASE)
110
+ # Dr. -> Dr
111
+ text = re.sub(r"\bDr\b\.", "Dr", text, flags=re.IGNORECASE)
112
+ # Jr. -> Jr
113
+ text = re.sub(r"\bJr\b\.", "Jr", text, flags=re.IGNORECASE)
114
+ # etc. -> etc
115
+ text = re.sub(r"\betc\b\.", "etc", text, flags=re.IGNORECASE)
116
+ pdf_chunk.text = text
117
+
118
+ return pdf_chunks
119
+
120
+
121
+ def _remove_chunk(
122
+ pdf_chunks: list[GenericNode],
123
+ chunk_index: int | None=None,
124
+ chunk_id: str | None=None
125
+ ) -> list[GenericNode]:
126
+ """Given a list of chunks, remove the chunk at the given index or with the given id.
127
+
128
+ Args:
129
+ pdf_chunks (List[GenericNode]): The list of chunks.
130
+ chunk_index (Optional[int]): The index of the chunk to remove.
131
+ chunk_id (Optional[str]): The id of the chunk to remove.
132
+
133
+ Returns:
134
+ List[GenericNode]: The updated list of chunks, without the removed chunk.
135
+ """
136
+ if (chunk_index is None and chunk_id is None):
137
+ msg = "_remove_chunk: Either chunk_index or chunk_id must be set."
138
+ raise ValueError(msg)
139
+
140
+ # Convert chunk_id to chunk_index
141
+ elif (chunk_index is None):
142
+ chunk = next((c for c in pdf_chunks if c.node_id == chunk_id), None)
143
+ if chunk is not None:
144
+ chunk_index = pdf_chunks.index(chunk)
145
+ else:
146
+ msg = f"_remove_chunk: No chunk found with id {chunk_id}."
147
+ raise ValueError(msg)
148
+ elif (chunk_index < 0 or chunk_index >= len(pdf_chunks)):
149
+ msg = f"_remove_chunk: Chunk {chunk_index} is out of range. Maximum index is {len(pdf_chunks) - 1}."
150
+ raise ValueError(msg)
151
+
152
+ # Update the previous-next node relationships around that index
153
+ def _node_rel_prev_next(prev_node: GenericNode, next_node: GenericNode) -> tuple[GenericNode, GenericNode]:
154
+ """Update pre-next node relationships between two nodes."""
155
+ prev_node.relationships[NodeRelationship.NEXT] = RelatedNodeInfo(
156
+ node_id=next_node.node_id,
157
+ metadata={"filename": next_node.metadata["filename"]}
158
+ )
159
+ next_node.relationships[NodeRelationship.PREVIOUS] = RelatedNodeInfo(
160
+ node_id=prev_node.node_id,
161
+ metadata={"filename": prev_node.metadata["filename"]}
162
+ )
163
+ return (prev_node, next_node)
164
+
165
+ if (chunk_index > 0 and chunk_index < len(pdf_chunks) - 1):
166
+ pdf_chunks[chunk_index - 1], pdf_chunks[chunk_index + 1] = _node_rel_prev_next(prev_node=pdf_chunks[chunk_index - 1], next_node=pdf_chunks[chunk_index + 1])
167
+
168
+ popped_chunk = pdf_chunks.pop(chunk_index)
169
+ chunk_id = chunk_id or popped_chunk.node_id
170
+
171
+ # Remove any references to the removed chunk in node relationships or metadata
172
+ for node in pdf_chunks:
173
+ node.relationships = {k: v for k, v in node.relationships.items() if v.node_id != chunk_id}
174
+ node.metadata = {k: v for k, v in node.metadata.items() if ((isinstance(v, list) and (chunk_id in v)) or (v != chunk_id))}
175
+ return pdf_chunks
176
+
177
+
178
+ def _clean_overlap_text(
179
+ text1: str,
180
+ text2: str,
181
+ combining_text: str=" ",
182
+ min_length: int | None = 1,
183
+ max_length: int | None = 50,
184
+ overlap_threshold: float = 0.9
185
+ ) -> str:
186
+ r"""Remove any overlapping text between two strings.
187
+
188
+ Args:
189
+ text1 (str): The first string.
190
+ text2 (str): The second string.
191
+ combining_text (str, optional): The text to combine the two strings with. Defaults to space (' '). Can also be \n.
192
+ min_length (int, optional): The minimum length of the overlap. Defaults to 1. None is no minimum.
193
+ max_length (int, optional): The maximum length of the overlap. Defaults to 50. None is no maximum.
194
+ overlap_threshold (float, optional): The threshold for being an overlap. Defaults to 0.9.
195
+
196
+ Returns:
197
+ str: The strings combined with the overlap removed.
198
+ """
199
+ for overlap_len in range(min(len(text1), len(text2), (max_length or len(text1))), ((min_length or 1)-1), -1):
200
+ end_substring = text1[-overlap_len:]
201
+ start_substring = text2[:overlap_len]
202
+ similarity = difflib.SequenceMatcher(None, end_substring, start_substring).ratio()
203
+ if (similarity >= overlap_threshold):
204
+ return combining_text.join([text1[:-overlap_len].rstrip(), text2.lstrip()]).strip()  # keep one copy of the overlapping text (from text2)
205
+
206
+ return combining_text.join([text1, text2]).strip()
207
+
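
A tiny illustrative call of the helper above (values made up; assumes the module is importable as pdf_reader_utils): with the 0.9 default threshold, the duplicated text at the seam is kept only once in the joined result.

    from pdf_reader_utils import _clean_overlap_text

    joined = _clean_overlap_text("the quick brown fox", "brown fox jumps over")
    # joined == "the quick brown fox jumps over"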
208
+
209
+ def _combine_chunks(c1: GenericNode, c2: GenericNode) -> GenericNode:
210
+ """Combine two chunks into one.
211
+
212
+ Args:
213
+ c1 (GenericNode): The first chunk.
214
+ c2 (GenericNode): The second chunk.
215
+
216
+ Returns:
217
+ GenericNode: The combined chunk.
218
+ """
219
+ # Metadata merging
220
+ # Type merging
221
+ text_types = ["NarrativeText", "ListItem", "Formula", "UncategorizedText", "Composite-TextOnly"]
222
+ image_types = ["FigureCaption", "Image"] # things that make Image nodes.
223
+
224
+ def _combine_chunks_type(c1_type: str, c2_type: str) -> str:
225
+ """Combine the types of two chunks.
226
+
227
+ Args:
228
+ c1_type (str): The type of the first chunk.
229
+ c2_type (str): The type of the second chunk.
230
+
231
+ Returns:
232
+ str: The type of the combined chunk.
233
+ """
234
+ if (c1_type == c2_type):
235
+ return c1_type
236
+ elif (c1_type in text_types and c2_type in text_types):
237
+ return "Composite-TextOnly"
238
+ elif (c1_type in image_types and c2_type in image_types):
239
+ return "Image" # Add caption to image
240
+ else:
241
+ return "Composite"
242
+
243
+ c1_type = c1.metadata["type"]
244
+ c2_type = c2.metadata["type"]
245
+ c1.metadata["type"] = _combine_chunks_type(c1_type, c2_type)
246
+
247
+ # All other metadata merging
248
+ for k, v in c2.metadata.items():
249
+ if k not in c1.metadata:
250
+ c1.metadata[k] = v
251
+ # Merge lists
252
+ elif k in ["page_number", 'page_name', 'languages', 'emphasized_text_contents', 'link_texts', 'link_urls']:
253
+ if not isinstance(c1.metadata[k], list):
254
+ c1.metadata[k] = [c1.metadata[k]]  # wrap the scalar in a list rather than iterating over it
255
+ if (v not in c1.metadata[k]):
256
+ # Add to list, dedupe
257
+ c1.metadata[k].extend(v if isinstance(v, list) else [v])
258
+ c1.metadata[k] = sorted(set(c1.metadata[k]))
259
+
260
+ # Text merging
261
+ c1_text = getattr(c1, "text", "")
262
+ c2_text = getattr(c2, "text", "")
263
+ if (c1_text == c2_text):
264
+ # No duplicates.
265
+ return c1
266
+ if (c1_text == "" or c2_text == ""):
267
+ c1.text = c1_text + c2_text
268
+ return c1
269
+
270
+ # Check if a sentence has been split between two chunks
271
+ # Option 1: letters
272
+ c1_text_last = c1_text[-1]
273
+
274
+ # Check if c1_text_last has a lowercase letter, digit, or punctuation that doesn't end a sentence
275
+ if (re.search(r'[\da-z\[\]\(\)\{\}\<\>\%\^\&\"\'\:\;\,\/\-\_\+\= \t\n\r]', c1_text_last)):
276
+ # We can probably combine these two texts as if they were on the same line.
277
+ c1.text = _clean_overlap_text(c1_text, c2_text, combining_text=" ")
278
+ else:
279
+ # We'll treat these as if they were on separate lines.
280
+ c1.text = _clean_overlap_text(c1_text, c2_text, combining_text="\n")
281
+
282
+ # NOTE: Relationships merging is handled in other functions, because it requires looking back at prior prior chunks.
283
+ return c1
284
+
285
+ def dedupe_title_chunks(pdf_chunks: list[GenericNode]) -> list[GenericNode]:
286
+ """Given a list of chunks, return a list of chunks without any title duplicates.
287
+
288
+ Args:
289
+ pdf_chunks (List[BaseNode]): The list of chunks to have titles deduped.
290
+
291
+ Returns:
292
+ List[BaseNode]: The deduped list of chunks.
293
+ """
294
+ index = 0
295
+ while (index < len(pdf_chunks)):
296
+ if (
297
+ (pdf_chunks[index].metadata["type"] in ("Title")) # is title
298
+ and (index > 0) # is not first chunk
299
+ and (pdf_chunks[index - 1].metadata["type"] in ("Title")) # previous chunk is also title
300
+ ):
301
+ # if (getattr(pdf_chunks[index], 'text', None) != getattr(pdf_chunks[index - 1], 'text', '')):
302
+ # pdf_chunks[index].text = getattr(pdf_chunks[index - 1], 'text', '') + '\n' + getattr(pdf_chunks[index], 'text', '')
303
+ pdf_chunks[index] = _combine_chunks(pdf_chunks[index - 1], pdf_chunks[index])
304
+
305
+ # NOTE: We'll remove the PRIOR title, since duplicates AND child relationships are built on the CURRENT title.
306
+ # There shouldn't be any PARENT/CHILD relationships to the title that we are deleting, so this seems fine.
307
+ pdf_chunks = _remove_chunk(pdf_chunks=pdf_chunks, chunk_index=index-1)
308
+ # NOTE: don't need to shift index because we removed an element.
309
+ else:
310
+ # We don't care about any situations other than consecutive title chunks.
311
+ index += 1
312
+
313
+ return (pdf_chunks)
314
+
315
+
316
+ def combine_listitem_chunks(pdf_chunks: list[GenericNode]) -> list[GenericNode]:
317
+ """Given a list of chunks, combine any adjacent chunks which are ListItems into one List.
318
+
319
+ Args:
320
+ pdf_chunks (List[GenericNode]): The list of chunks to combine.
321
+
322
+ Returns:
323
+ List[GenericNode]: The list of chunks with ListItems combined into one List chunk.
324
+ """
325
+ index = 0
326
+ while (index < len(pdf_chunks)):
327
+ if (
328
+ (pdf_chunks[index].metadata["type"] == "ListItem") # is list item
329
+ and (index > 0) # is not first chunk
330
+ and (pdf_chunks[index - 1].metadata["type"] == "ListItem") # previous chunk is also list item
331
+ ):
332
+ # Okay, we have a consecutive list item. Combine into one list.
333
+ # NOTE: We'll remove the PRIOR list item, since duplicates AND child relationships are built on the CURRENT list item.
334
+ # 1. Append prior list item's text to the current list item's text
335
+ # pdf_chunks[index].text = getattr(pdf_chunks[index - 1], 'text', '') + '\n' + getattr(pdf_chunks[index], 'text', '')
336
+ pdf_chunks[index] = _combine_chunks(pdf_chunks[index - 1], pdf_chunks[index])
337
+ # 2. Remove PRIOR list item
338
+ pdf_chunks.pop(index - 1)
339
+ # 3. Replace NEXT relationship from PRIOR list item with the later list item node ID, if prior prior node exists.
340
+ if (index - 2 >= 0):
341
+ pdf_chunks[index - 2].relationships[NodeRelationship.NEXT] = RelatedNodeInfo(
342
+ node_id=pdf_chunks[index].node_id,
343
+ metadata={"filename": pdf_chunks[index].metadata["filename"]}
344
+ )
345
+ # 4. Replace PREVIOUS relationship from LATER list item with the prior prior node ID, if prior prior node exists.
346
+ pdf_chunks[index].relationships[NodeRelationship.PREVIOUS] = RelatedNodeInfo(
347
+ node_id=pdf_chunks[index - 2].node_id,
348
+ metadata={"filename": pdf_chunks[index - 2].metadata['filename']}
349
+ )
350
+ # NOTE: the PARENT/CHILD relationships should be the same as the previous list item, so this seems fine.
351
+ else:
352
+ # We don't care about any situations other than consecutive list item chunks.
353
+ index += 1
354
+ return (pdf_chunks)
355
+
356
+
357
+ def remove_header_footer_repeated(
358
+ pdf_chunks_input: list[GenericNode],
359
+ window_size: int = 3,
360
+ fuzz_threshold: int = 80
361
+ ) -> list[GenericNode]:
362
+ """Given a list of chunks, remove any header/footer chunks that are repeated across pages.
363
+
364
+ Args:
365
+ pdf_chunks (List[GenericNode]): The list of chunks to process.
366
+ window_size (int): The number of chunks to consider at the beginning and end of each page.
367
+ fuzz_threshold (int): The threshold for fuzzy matching of chunk texts.
368
+
369
+ Returns:
370
+ List[GenericNode]: The list of chunks with header/footer chunks removed.
371
+ """
372
+ nodes_to_remove = set() # id's to remove.
373
+ pdf_chunks = deepcopy(pdf_chunks_input)
374
+
375
+ # Build a dictionary of chunks by page number
376
+ chunks_by_page = defaultdict(list)
377
+ for chunk in pdf_chunks:
378
+ chunk_page_number = min(chunk.metadata["page_number"]) if isinstance(chunk.metadata["page_number"], list) else chunk.metadata["page_number"]
379
+ chunks_by_page[chunk_page_number].append(chunk)
380
+
381
+ # Get the first window_size and last window_size chunks on each page
382
+ header_candidates = defaultdict(set) # hashmap of chunk text, and set of chunk ids with that text.
383
+ footer_candidates = defaultdict(set) # hashmap of chunk text, and set of chunk ids with that text.
384
+ page_number_regex = re.compile(r"(?:-|\( ?)?\b(?:page|p\.?(?:[pg](?:\b|\.)?)?)? ?(?:\d+|\b[ivxm]+\b)\.?(?: ?-|\))?\b", re.IGNORECASE)
385
+ for chunks in chunks_by_page.values():
386
+ header_chunks = chunks[:window_size]
387
+ footer_chunks = chunks[-window_size:]
388
+
389
+ for chunk in header_chunks:
390
+ chunk_text = getattr(chunk, "text", "")
391
+ if chunk.metadata["type"] == "Header" and len(chunk_text) > 0:
392
+ chunk_text_is_pagenum_only = page_number_regex.match(chunk_text)
393
+ if chunk_text_is_pagenum_only and (len(chunk_text_is_pagenum_only.group(0)) == len(chunk_text)):
394
+ # Full match!
395
+ chunk.text = "Page Number Only"
396
+ nodes_to_remove.add(chunk.node_id)
397
+ elif chunk_text_is_pagenum_only and len(chunk_text_is_pagenum_only.group(0)) > 0:
398
+ # Remove the page number content from the chunk text for this exercise
399
+ chunk_text = page_number_regex.sub('', chunk_text)
400
+ chunk.text = chunk_text
401
+
402
+ if chunk.metadata["type"] not in ("Image", "Table") and len(chunk_text) > 0:
403
+ header_candidates[chunk_text].add(chunk.node_id)
404
+
405
+ for chunk in footer_chunks:
406
+ chunk_text = getattr(chunk, "text", "")
407
+ if chunk.metadata["type"] == "Footer" and len(chunk_text) > 0:
408
+ chunk_text_is_pagenum_only = page_number_regex.match(chunk_text)
409
+ if chunk_text_is_pagenum_only and (len(chunk_text_is_pagenum_only.group(0)) == len(chunk_text)):
410
+ # Full match!
411
+ chunk.text = "Page Number Only"
412
+ nodes_to_remove.add(chunk.node_id)
413
+ elif chunk_text_is_pagenum_only and len(chunk_text_is_pagenum_only.group(0)) > 0:
414
+ # Remove the page number content from the chunk text for this exercise
415
+ chunk_text = page_number_regex.sub('', chunk_text)
416
+ chunk.text = chunk_text
417
+
418
+ if chunk.metadata["type"] not in ("Image", "Table") and len(chunk_text) > 0:
419
+ footer_candidates[chunk_text].add(chunk.node_id)
420
+
421
+ # Identify any texts which are too similar to other header texts.
422
+ header_texts = list(header_candidates.keys())
423
+ header_distance_matrix = rapidfuzz.process.cdist(header_texts, header_texts, scorer=rapidfuzz.fuzz.ratio, score_cutoff=fuzz_threshold)
424
+
425
+ footer_texts = list(footer_candidates.keys())
426
+ footer_distance_matrix = rapidfuzz.process.cdist(footer_texts, footer_texts, scorer=rapidfuzz.fuzz.ratio, score_cutoff=fuzz_threshold)
427
+ # Combine header candidates which are too similar to each other in the distance matrix
428
+ for i in range(len(header_distance_matrix)-1):
429
+ for j in range(i+1, len(header_distance_matrix)):
430
+ if i == j:
431
+ continue
432
+ if header_distance_matrix[i][j] >= fuzz_threshold:
433
+ header_candidates[header_texts[i]].update(header_candidates[header_texts[j]])
434
+ header_candidates[header_texts[j]].update(header_candidates[header_texts[i]])
435
+
436
+ for i in range(len(footer_distance_matrix)-1):
437
+ for j in range(i+1, len(footer_distance_matrix)):
438
+ if i == j:
439
+ continue
440
+ if footer_distance_matrix[i][j] >= fuzz_threshold:
441
+ footer_candidates[footer_texts[i]].update(footer_candidates[footer_texts[j]])
442
+ footer_candidates[footer_texts[j]].update(footer_candidates[footer_texts[i]])
443
+
444
+ headers_to_remove = set()
445
+ for chunk_ids in header_candidates.values():
446
+ if len(chunk_ids) > 1:
447
+ headers_to_remove.update(chunk_ids)
448
+
449
+ footers_to_remove = set()
450
+ for chunk_ids in footer_candidates.values():
451
+ if len(chunk_ids) > 1:
452
+ footers_to_remove.update(chunk_ids)
453
+
454
+ nodes_to_remove = nodes_to_remove.union(headers_to_remove.union(footers_to_remove))
455
+
456
+ for node_id in nodes_to_remove:
457
+ pdf_chunks = _remove_chunk(pdf_chunks=pdf_chunks, chunk_id=node_id)
458
+
459
+ return pdf_chunks
460
+
461
+ def remove_overlap_images(pdf_chunks: list[GenericNode]) -> list[GenericNode]:
462
+ # TODO(Jonathan Wang): Implement this function to remove images which are completely overlapping each other
463
+ # OR... get a better dang reader!
464
+ raise NotImplementedError
465
+
466
+
467
+ def chunk_by_header(
468
+ pdf_chunks_in: list[GenericNode],
469
+ combine_text_under_n_chars: int = 1024,
470
+ multipage_sections: bool = True,
471
+ # ) -> Tuple[List[GenericNode], List[GenericNode]]:
472
+ ) -> list[GenericNode]:
473
+ """Combine chunks together that are part of the same header and have similar meaning.
474
+
475
+ Args:
476
+ pdf_chunks (List[GenericNode]): List of chunks to be combined.
477
+
478
+ Returns:
479
+ List[GenericNode]: List of combined chunks.
480
+ List[GenericNode]: List of original chunks, with node references updated.
481
+ """
482
+ # TODO(Jonathan Wang): Handle semantic chunking between elements within a Header chunk.
483
+ # TODO(Jonathan Wang): Handle splitting element chunks if they are over `max_characters` in length (does this ever really happen?)
484
+ # TODO(Jonathan Wang): Handle relationships between nodes.
485
+
486
+ pdf_chunks = deepcopy(pdf_chunks_in)
487
+ output = []
488
+ id_to_index = {}
489
+ index = 0
490
+
491
+ # Pass 1: Combine chunks together that are part of the same title chunk.
492
+ while (index < len(pdf_chunks)):
493
+ chunk = pdf_chunks[index]
494
+ if (chunk.metadata["type"] in ["Header", "Footer", "Image", "Table"]):
495
+ # These go immediately into the semantic title chunks and also reset the new node.
496
+
497
+ # Let's add a newline to distinguish from any other content.
498
+ if (chunk.metadata["type"] in ["Header", "Footer", "Table"]):
499
+ chunk.text = getattr(chunk, "text", "") + "\n"
500
+
501
+ output.append(chunk)
502
+ index += 1
503
+ continue
504
+
505
+ # Make a new node if we have a new title (or if we don't have a title).
506
+ if (
507
+ chunk.metadata["type"] == "Title"
508
+ ):
509
+ # We're good, this node can stay as a TitleChunk.
510
+ chunk.metadata['type'] = 'Composite'
511
+ # if (not isinstance(chunk.metadata['page number'], list)):
512
+ # chunk.metadata['page number'] = [chunk.metadata['page number']]
513
+
514
+ # Let's add a newline to distinguish the title from the content.
515
+ setattr(chunk, 'text', getattr(chunk, 'text', '') + "\n")
516
+
517
+ output.append(chunk)
518
+ id_to_index[chunk.id_] = len(output) - 1
519
+ index += 1
520
+ continue
521
+
522
+ elif (chunk.metadata.get('parent_id', None) in id_to_index):
523
+ # This chunk is part of the same title as a prior chunk.
524
+ # Add this text into the prior title node.
525
+ jndex = id_to_index[chunk.metadata['parent_id']]
526
+
527
+ # if (not isinstance(output[jndex].metadata['page number'], list)):
528
+ # output[jndex].metadata['page number'] = [chunk.metadata['page number']]
529
+
530
+ output[jndex] = _combine_chunks(output[jndex], chunk)
531
+ # output[jndex].text = getattr(output[jndex], 'text', '') + '\n' + getattr(chunk, 'text', '')
532
+ # output[jndex].metadata['page number'] = list(set(output[jndex].metadata['page number'] + [chunk.metadata['page number']]))
533
+ # output[jndex].metadata['languages'] = list(set(output[jndex].metadata['languages'] + chunk.metadata['languages']))
534
+
535
+ pdf_chunks.remove(chunk)
536
+ continue
537
+
538
+ elif (
539
+ (chunk.metadata.get('parent_id', None) is None)
540
+ and (
541
+ len(getattr(chunk, 'text', '')) > combine_text_under_n_chars # big enough text section to stand alone
542
+ or (len(id_to_index.keys()) <= 0) # no prior title
543
+ )
544
+ ):
545
+ # Okay, so either we don't have a title, or it was interrupted by an image / table.
546
+ # This chunk can stay as a TextChunk.
547
+ chunk.metadata['type'] = 'Composite-TextOnly'
548
+ # if (not isinstance(chunk.metadata['page number'], list)):
549
+ # chunk.metadata['page number'] = [chunk.metadata['page number']]
550
+
551
+ output.append(chunk)
552
+ id_to_index[chunk.id_] = len(output) - 1
553
+ index += 1
554
+ continue
555
+
556
+ else:
557
+ # Add the text to the prior node that isn't a table or image.
558
+ jndex = len(output) - 1
559
+ while (
560
+ (jndex >= 0)
561
+ and (output[jndex].metadata['type'] in ['Table', 'Image'])
562
+ ):
563
+ # for title_chunk in output:
564
+ # print(f'''{title_chunk.id_}: {title_chunk.metadata['type']}, text: {title_chunk.text}, parent: {title_chunk.metadata['parent_id']}''')
565
+ jndex -= 1
566
+
567
+ if (jndex < 0):
568
+ raise Exception(f'''Prior title chunk not found: {index}, {chunk.metadata.get('parent_id', None)}''')
569
+
570
+ # Add this text into the prior title node.
571
+ # if (not isinstance(output[jndex].metadata['page number'], list)):
572
+ # output[jndex].metadata['page number'] = [chunk.metadata['page number']]
573
+
574
+ output[jndex] = _combine_chunks(output[jndex], chunk)
575
+ # output[jndex].text = getattr(output[jndex], 'text', '') + ' ' + getattr(chunk, 'text', '')
576
+ # output[jndex].metadata['page number'] = list(set(output[jndex].metadata['page number'] + [chunk.metadata['page number']]))
577
+ # output[jndex].metadata['languages'] = list(set(output[jndex].metadata['languages'] + chunk.metadata['languages']))
578
+
579
+ pdf_chunks.remove(chunk)
580
+ # TODO: Update relationships between nodes.
581
+ continue
582
+
583
+ return (output)
584
+
585
+
586
+ ### TODO:
587
+ # Merge images together that are substantially overlapping.
588
+ # Favour image with no confidence score. (these come straight from pdf).
589
+ # Favour the larger image over the smaller one.
590
+ # Favour the image with higher confidence score.
591
+ def merge_images() -> None:
592
+ pass
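
A sketch of the post-processing order these helpers appear intended for, mirroring the commented-out pipeline in pdf_reader.load_data; the input PDF path is hypothetical and the thresholds are simply the defaults defined above.

    from pdf_reader import UnstructuredPDFReader
    from pdf_reader_utils import (
        chunk_by_header,
        clean_abbreviations,
        combine_listitem_chunks,
        dedupe_title_chunks,
        remove_header_footer_repeated,
    )

    nodes = UnstructuredPDFReader().load_data(file_path="data/example.pdf")  # hypothetical path
    nodes = clean_abbreviations(nodes)
    nodes = dedupe_title_chunks(nodes)
    nodes = combine_listitem_chunks(nodes)
    nodes = remove_header_footer_repeated(nodes, window_size=3, fuzz_threshold=80)
    nodes = chunk_by_header(nodes, combine_text_under_n_chars=1024)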
prompts.py ADDED
@@ -0,0 +1,86 @@
1
+ #####################################################
2
+ ### DOCUMENT PROCESSOR [PROMPTS]
3
+ #####################################################
4
+ # Jonathan Wang
5
+
6
+ # ABOUT:
7
+ # This project creates an app to chat with PDFs.
8
+
9
+ # This is the prompts sent to the LLM.
10
+ #####################################################
11
+ ## TODOS:
12
+ # Use the row names instead of .at indesx locators
13
+ # This is kinda dumb because we read the same .csv file over again
14
+ # Should we structure this abstraction differently?
15
+
16
+ #####################################################
17
+ ## IMPORTS:
18
+ import pandas as pd
19
+ from llama_index.core import PromptTemplate
20
+
21
+ #####################################################
22
+ ## CODE:
23
+
24
+ # https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/prompts/default_prompts.py
25
+ QA_PROMPT = """Context information is below.\n
26
+ ---------------------
27
+ {context_str}
28
+ ---------------------
29
+ Given the context information, answer the query.
30
+ You must adhere to the following rules:
31
+ - Use the context information, not prior knowledge.
32
+ - End the answer with any brief quote(s) from the context that are the most essential in answering the question.
33
+ - If the context is not helpful in answering the question, do not include a quote.
34
+
35
+ Query: {query_str}
36
+ Answer: """
37
+
38
+ # https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/prompts/default_prompts.py
39
+ REFINE_PROMPT = """The original query is as follows: {query_str}
40
+ We have provided an existing answer: {existing_answer}
41
+ We have the opportunity to refine the existing answer (only if needed) with some more context below.
42
+ ---------------------
43
+ {context_msg}
44
+ ---------------------
45
+ Given the new context, refine the original answer to better answer the query.
46
+ You must adhere to the following rules:
47
+ - If the context isn't useful, return the original answer.
48
+ - End the answer with any brief quote(s) from the original answer or new context that are the most essential in answering the question.
49
+ - If the new context is not helpful in answering the question, leave the original answer unchanged.
50
+
51
+ Refined Answer: """
52
+
53
+ def get_qa_prompt(
54
+ # prompt_file_path: str
55
+ ) -> PromptTemplate:
56
+ """Given a path to the prompts, get prompt for Question-Answering"""
57
+ # prompts = pd.read_csv(prompt_file_path)
58
+ # https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/prompts/default_prompts.py
59
+ custom_qa_prompt = PromptTemplate(
60
+ QA_PROMPT
61
+ )
62
+ return (custom_qa_prompt)
63
+
64
+
65
+ def get_refine_prompt(
66
+ # prompt_file_path: str
67
+ ) -> PromptTemplate:
68
+ """Given a path to the prompts, get prompt to Refine answer after new info"""
69
+ # prompts = pd.read_csv(prompt_file_path)
70
+ # https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/prompts/default_prompts.py
71
+ custom_refine_prompt = PromptTemplate(
72
+ REFINE_PROMPT
73
+ )
74
+ return (custom_refine_prompt)
75
+
76
+
77
+ # def get_reqdoc_prompt(
78
+ # prompt_file_path: str
79
+ # ) -> PromptTemplate:
80
+ # """Given a path to the prompts, get prompt to identify requested info from document."""
81
+ # prompts = pd.read_csv(prompt_file_path)
82
+ # # https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/prompts/default_prompts.py
83
+ # reqdoc_prompt = PromptTemplate(
84
+ # prompts.at[2, 'Prompt']
85
+ # )
86
+ # return (reqdoc_prompt)
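A hedged usage sketch of these templates: get_response_synthesizer and its text_qa_template / refine_template parameters come from llama_index.core, but whether the rest of this app wires the prompts in exactly this way is an assumption.

from llama_index.core import get_response_synthesizer
from prompts import get_qa_prompt, get_refine_prompt

synth = get_response_synthesizer(
    text_qa_template=get_qa_prompt(),     # fills {context_str} / {query_str}
    refine_template=get_refine_prompt(),  # fills {query_str} / {existing_answer} / {context_msg}
)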
pyproject.toml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://www.reddit.com/r/Python/comments/13h2xuc/any_musthave_extensions_for_working_with_python/
2
+
3
+ [tool.isort]
4
+ profile = "black"
5
+
6
+ [tool.mypy]
7
+ warn_unused_configs = true
8
+ exclude = "archives|build|docs"
9
+ show_column_numbers = true
10
+ show_error_codes = true
11
+ strict = true
12
+ plugins = ["numpy.typing.mypy_plugin"]
13
+
14
+ [tool.ruff]
15
+ select = ["ALL"]
16
+ ignore = [
17
+ "ANN101", # Missing type annotation for self in method
18
+ "COM", # flake8-commas
19
+ "D100", # Missing docstring in public module
20
+ "D101", # Missing docstring in public class
21
+ "D102", # Missing docstring in public method
22
+ "D103", # Missing docstring in public function
23
+ "D104", # Missing docstring in public package
24
+ "D406", # Section name should end with a newline
25
+ "D407", # Missing dashed underline after section
26
+ "FBT", # flake8-boolean-trap
27
+ "G004", # Logging statement uses f-string
28
+ # "PD901", # df is a bad variable name. Be kinder to your future self.
29
+ "PTH123", # open() should be replaced by Path.open()
30
+ "RET505", # Unnecessary `elif` after `return` statement (I think this improves readability)
31
+ "RET506", # Unnecessary `else` after `return` statement (I think this improves readability)
32
+ "T20", # flake8-print
33
+ "TD003", # Missing issue link on the line following this TODO (I don't have an issue system)
34
+ ]
35
+ src = ["src"]
36
+
37
+ [tool.ruff.per-file-ignores]
38
+ "tests/**/*.py" = [
39
+ "S101", # Use of assert detected
40
+ ]
41
+
42
+ [tool.ruff.pydocstyle]
43
+ convention = "numpy"
44
+
45
+ [tool.pyright]
46
+ typeCheckingMode = "strict"
47
+
48
+ reportMissingTypeStubs = false
49
+ reportPrivateUsage = false
50
+ reportUnknownArgumentType = false
51
+ reportUnknownMemberType = false
52
+ reportUnknownParameterType = false
53
+ reportUnknownVariableType = false
requirements.txt ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu124
2
+ torch>=2.4.0+cu124
3
+ torchaudio>=2.4.0+cu124
4
+ torchvision>=0.19.0+cu124
5
+ transformers>=4.41.1
6
+ accelerate>=0.28.0
7
+ quanto
8
+ optimum-quanto # bitsandbytes replacement, seems better?
9
+ sentence-transformers
10
+ einops
11
+ einops_exts
12
+ open_clip_torch>=2.24.0
13
+ treelib
14
+ nltk>=3.9
15
+ # multi-rake
16
+ yake
17
+ symspellpy
18
+ rapidfuzz
19
+ streamlit
20
+ streamlit-pdf-viewer
21
+ opencv-python
22
+ pdf2image
23
+ pytesseract
24
+ pdfplumber>=0.11.3
25
+ pdfminer.six>=20231228 # fixes a possible infinite loop in the PDFMiner read-in used by unstructured[all-docs]
26
+ unstructured[all-docs]>=0.15.5
27
+ llama-index-core
28
+ llama-index-embeddings-huggingface
29
+ llama-index-vector-stores-qdrant
30
+ llama-index-retrievers-bm25
31
+ llama-index-llms-huggingface
32
+ llama-index-llms-groq
33
+ llama-index-question-gen-openai # required for subquestionqueryengine
34
+ llama-index-multi-modal-llms-openai
retriever.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #####################################################
2
+ ### DOCUMENT PROCESSOR [RETRIEVER]
3
+ #####################################################
4
+ # Jonathan Wang
5
+
6
+ # ABOUT:
7
+ # This project creates an app to chat with PDFs.
8
+
9
+ # This is the RETRIEVER
10
+ # which defines the main way that document
11
+ # snippets are identified.
12
+
13
+ #####################################################
14
+ ## TODO:
15
+
16
+ #####################################################
17
+ ## IMPORTS:
18
+ import logging
19
+ from typing import Optional, List, Tuple, Dict, cast
20
+ from collections import defaultdict
21
+
22
+ import streamlit as st
23
+
24
+ import numpy as np
25
+
26
+ from llama_index.core.utils import truncate_text
27
+ from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
28
+ from llama_index.retrievers.bm25 import BM25Retriever
29
+
30
+ from llama_index.core import VectorStoreIndex #, StorageContext,
31
+ from llama_index.core.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
32
+ from llama_index.core.callbacks.base import CallbackManager
33
+
34
+ # Own Modules:
35
+ from merger import _merge_on_scores
36
+
37
+ # Lazy Loading:
38
+
39
+ #####################################################
40
+ ## CODE:
41
+ class RAGRetriever(BaseRetriever):
42
+ """
43
+ Jonathan Wang's custom-built retriever over our vector store.
44
+ Combination of Hybrid Retrieval (BM25 x Vector Embeddings) + AutoMergingRetriever
45
+ https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/retrievers/auto_merging_retriever.py
46
+ """
47
+ def __init__(
48
+ self,
49
+ vector_store_index: VectorStoreIndex,
50
+
51
+ semantic_top_k: int = 10,
52
+ sparse_top_k: int = 6,
53
+
54
+ fusion_similarity_top_k: int = 10, # total number of snippets to retrieve after the Reciprocal Rerank.
55
+ semantic_weight_fraction: float = 0.6, # percentage weight to give to semantic cosine vs sparse bm25
56
+ merge_up_thresh: float = 0.5, # fraction of nodes needed to be retrieved to merge up to semantic level
57
+
58
+ verbose: bool = True,
59
+ callback_manager: Optional[CallbackManager] = None,
60
+ object_map: Optional[dict] = None,
61
+ objects: Optional[List[IndexNode]] = None,
62
+ ) -> None:
63
+ """Init params."""
64
+ self._vector_store_index = vector_store_index
65
+
66
+ self.sentence_vector_retriever = VectorIndexRetriever(
67
+ index=vector_store_index, similarity_top_k=semantic_top_k
68
+ )
69
+ self.sentence_bm25_retriever = BM25Retriever.from_defaults(
70
+ # nodes=list(vector_store_index.storage_context.docstore.docs.values())
71
+ index=vector_store_index # TODO: Confirm this works.
72
+ , similarity_top_k=sparse_top_k
73
+ )
74
+
75
+ self._fusion_similarity_top_k = fusion_similarity_top_k
76
+ self._semantic_weight_fraction = semantic_weight_fraction
77
+ self._merge_up_thresh = merge_up_thresh
78
+
79
+ super().__init__(
80
+ # callback_manager=callback_manager,
81
+ object_map=object_map,
82
+ objects=objects,
83
+ verbose=verbose,
84
+ )
85
+
86
+
87
+ @classmethod
88
+ def class_name(cls) -> str:
89
+ """Class name."""
90
+ return "RAGRetriever"
91
+
92
+
93
+ def _get_parents_and_merge(
94
+ self, nodes: List[NodeWithScore]
95
+ ) -> Tuple[List[NodeWithScore], bool]:
96
+ """Get parents and merge nodes."""
97
+ # retrieve all parent nodes
98
+ parent_nodes: Dict[str, BaseNode] = {}
99
+ parent_cur_children_dict: Dict[str, List[NodeWithScore]] = defaultdict(list)
100
+ for node in nodes:
101
+ if node.node.parent_node is None:
102
+ continue
103
+ parent_node_info = node.node.parent_node
104
+
105
+ # Fetch actual parent node if doesn't exist in `parent_nodes` cache yet
106
+ parent_node_id = parent_node_info.node_id
107
+ if parent_node_id not in parent_nodes:
108
+ parent_node = self._vector_store_index.storage_context.docstore.get_document(
109
+ parent_node_id
110
+ )
111
+ parent_nodes[parent_node_id] = cast(BaseNode, parent_node)
112
+
113
+ # add reference to child from parent
114
+ parent_cur_children_dict[parent_node_id].append(node)
115
+
116
+ # compute ratios and "merge" nodes
117
+ # merging: delete some children nodes, add some parent nodes
118
+ node_ids_to_delete = set()
119
+ nodes_to_add: Dict[str, BaseNode] = {}
120
+ for parent_node_id, parent_node in parent_nodes.items():
121
+ parent_child_nodes = parent_node.child_nodes
122
+ parent_num_children = len(parent_child_nodes) if parent_child_nodes else 1
123
+ parent_cur_children = parent_cur_children_dict[parent_node_id]
124
+ ratio = len(parent_cur_children) / parent_num_children
125
+
126
+ # if ratio is high enough, merge up to the next level in the hierarchy
127
+ if ratio > self._merge_up_thresh:
128
+ node_ids_to_delete.update(
129
+ set({n.node.node_id for n in parent_cur_children})
130
+ )
131
+
132
+ parent_node_text = truncate_text(getattr(parent_node, 'text', ''), 100)
133
+ info_str = (
134
+ f"> Merging {len(parent_cur_children)} nodes into parent node.\n"
135
+ f"> Parent node id: {parent_node_id}.\n"
136
+ f"> Parent node text: {parent_node_text}\n"
137
+ )
138
+ # logger.info(info_str)
139
+ if self._verbose:
140
+ print(info_str)
141
+
142
+ # add parent node
143
+ # can try averaging score across embeddings for now
144
+ avg_score = sum(
145
+ [n.get_score() or 0.0 for n in parent_cur_children]
146
+ ) / len(parent_cur_children)
147
+ parent_node_with_score = NodeWithScore(
148
+ node=parent_node, score=avg_score
149
+ )
150
+ nodes_to_add[parent_node_id] = parent_node_with_score # type: ignore (dict is annotated for BaseNode, but we intentionally store NodeWithScore here)
151
+
152
+ # delete old child nodes, add new parent nodes
153
+ new_nodes = [n for n in nodes if n.node.node_id not in node_ids_to_delete]
154
+ # add parent nodes
155
+ new_nodes.extend(list(nodes_to_add.values())) # type: ignore (the dict holds NodeWithScore values, matching new_nodes)
156
+
157
+ is_changed = len(node_ids_to_delete) > 0
158
+ return new_nodes, is_changed
159
+
160
+
161
+ def _fill_in_nodes(
162
+ self, nodes: List[NodeWithScore]
163
+ ) -> Tuple[List[NodeWithScore], bool]:
164
+ """Fill in nodes."""
165
+ new_nodes = []
166
+ is_changed = False
167
+ for idx, node in enumerate(nodes):
168
+ new_nodes.append(node)
169
+ if idx >= len(nodes) - 1:
170
+ continue
171
+
172
+ cur_node = cast(BaseNode, node.node)
173
+ # if there's a node in the middle, add that to the queue
174
+ if (
175
+ cur_node.next_node is not None
176
+ and cur_node.next_node == nodes[idx + 1].node.prev_node
177
+ ):
178
+ is_changed = True
179
+ next_node = self._vector_store_index.storage_context.docstore.get_document(
180
+ cur_node.next_node.node_id
181
+ )
182
+ next_node = cast(BaseNode, next_node)
183
+
184
+ next_node_text = truncate_text(getattr(next_node, 'text', ''), 100) # TODO: why not higher?
185
+ info_str = (
186
+ f"> Filling in node. Node id: {cur_node.next_node.node_id}"
187
+ f"> Node text: {next_node_text}\n"
188
+ )
189
+ # logger.info(info_str)
190
+ if self._verbose:
191
+ print(info_str)
192
+
193
+ # set score to be average of current node and next node
194
+ avg_score = (node.get_score() + nodes[idx + 1].get_score()) / 2
195
+ new_nodes.append(NodeWithScore(node=next_node, score=avg_score))
196
+ return new_nodes, is_changed
197
+
198
+
199
+ def _try_merging(
200
+ self, nodes: List[NodeWithScore]
201
+ ) -> Tuple[List[NodeWithScore], bool]:
202
+ """Try different ways to merge nodes."""
203
+ # first try filling in nodes
204
+ nodes, is_changed_0 = self._fill_in_nodes(nodes)
205
+ # then try merging nodes
206
+ nodes, is_changed_1 = self._get_parents_and_merge(nodes)
207
+ return nodes, is_changed_0 or is_changed_1
208
+
209
+
210
+ def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
211
+ """Retrieve."""
212
+ # Get vector stores retrieved nodes
213
+ vector_sentence_nodes = self.sentence_vector_retriever.retrieve(query_bundle)# , **kwargs)
214
+ bm25_sentence_nodes = self.sentence_bm25_retriever.retrieve(query_bundle)# , **kwargs)
215
+
216
+ # Get initial nodes from hybrid search.
217
+ initial_nodes = _merge_on_scores(
218
+ vector_sentence_nodes,
219
+ bm25_sentence_nodes,
220
+ [getattr(a, "score", np.nan) for a in vector_sentence_nodes],
221
+ [getattr(b, "score", np.nan) for b in bm25_sentence_nodes],
222
+ a_weight=self._semantic_weight_fraction,
223
+ top_k=self._fusion_similarity_top_k
224
+ )
225
+
226
+ # Merge nodes
227
+ cur_nodes, is_changed = self._try_merging(list(initial_nodes)) # technically _merge_on_scores returns a sequence.
228
+ while is_changed:
229
+ cur_nodes, is_changed = self._try_merging(cur_nodes)
230
+
231
+ # sort by similarity
232
+ cur_nodes.sort(key=lambda x: x.get_score(), reverse=True)
233
+
234
+ # some other reranking and filtering node postprocessors here?
235
+ # https://docs.llamaindex.ai/en/stable/module_guides/querying/node_postprocessors/root.html
236
+ return cur_nodes
237
+
238
+ @st.cache_resource
239
+ def get_retriever(
240
+ _vector_store_index: VectorStoreIndex,
241
+
242
+ semantic_top_k: int = 10,
243
+ sparse_top_k: int = 6,
244
+
245
+ fusion_similarity_top_k: int = 10, # total number of snippets to retrieve after the Reciprocal Rerank.
246
+ semantic_weight_fraction: float = 0.6, # percentage weight to give to semantic (vector) scores vs sparse BM25
247
+ merge_up_thresh: float = 0.5, # fraction of nodes needed to be retrieved to merge up to semantic level
248
+
249
+ verbose: bool = True,
250
+ _callback_manager: Optional[CallbackManager] = None,
251
+ object_map: Optional[dict] = None,
252
+ objects: Optional[List[IndexNode]] = None,
253
+ ) -> BaseRetriever:
254
+ """Get the retriver to use.
255
+
256
+ Args:
257
+ vector_store_index (VectorStoreIndex): The vector store to query on.
258
+ semantic_top_k (int, optional): Top k nodes to retrieve semantically (cosine). Defaults to 10.
259
+ sparse_top_k (int, optional): Top k nodes to retrieve sparsely (BM25). Defaults to 6.
260
+ fusion_similarity_top_k (int, optional): Maximum number of nodes to retrieve after fusing. Defaults to 10.
261
+ callback_manager (Optional[CallbackManager], optional): Callback manager. Defaults to None.
262
+ object_map (Optional[dict], optional): Object map. Defaults to None.
263
+ objects (Optional[List[IndexNode]], optional): Objects list. Defaults to None.
264
+
265
+ Returns:
266
+ BaseRetriever: Retriever to use.
267
+ """
268
+ retriever = RAGRetriever(
269
+ vector_store_index=_vector_store_index,
270
+ semantic_top_k=semantic_top_k,
271
+ sparse_top_k=sparse_top_k,
272
+ fusion_similarity_top_k=fusion_similarity_top_k,
273
+ semantic_weight_fraction=semantic_weight_fraction,
274
+ merge_up_thresh=merge_up_thresh,
275
+ verbose=verbose,
276
+ callback_manager=_callback_manager,
277
+ object_map=object_map,
278
+ objects=objects
279
+ )
280
+ return (retriever)
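A usage sketch for get_retriever; vector_store_index is assumed to be a populated VectorStoreIndex built elsewhere in this repo (e.g. via storage.py), and the query string is illustrative only.

from retriever import get_retriever

retriever = get_retriever(
    _vector_store_index=vector_store_index,  # a populated VectorStoreIndex (assumed to exist)
    semantic_top_k=10,
    sparse_top_k=6,
    fusion_similarity_top_k=10,
)
nodes = retriever.retrieve("What are the key findings of the document?")
for node_with_score in nodes:
    print(node_with_score.get_score(), node_with_score.node.node_id)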
storage.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #####################################################
2
+ ### DOCUMENT PROCESSOR [STORAGE]
3
+ #####################################################
4
+ # Jonathan Wang
5
+
6
+ # ABOUT:
7
+ # This project creates an app to chat with PDFs.
8
+
9
+ # This is the setup for the Storage in the RAG pipeline.
10
+ #####################################################
11
+ ## TODOS:
12
+ # Handle creating multiple vector stores, one for each document which has been processed (?)
13
+
14
+ #####################################################
15
+ ## IMPORTS:
16
+ import gc
17
+ from torch.cuda import empty_cache
18
+
19
+ from typing import Optional, IO, List, Tuple
20
+
21
+ import streamlit as st
22
+
23
+ import qdrant_client
24
+ from llama_index.core import StorageContext
25
+ from llama_index.core.storage.docstore.types import BaseDocumentStore
26
+ from llama_index.core.storage.docstore import SimpleDocumentStore
27
+ from llama_index.vector_stores.qdrant import QdrantVectorStore
28
+ from llama_index.core import VectorStoreIndex
29
+
30
+ from llama_index.core.settings import Settings
31
+ from llama_index.core.base.embeddings.base import BaseEmbedding
32
+ from llama_index.core.node_parser import NodeParser
33
+
34
+ # Reader and processing
35
+ from pdf_reader import UnstructuredPDFReader
36
+ from pdf_reader_utils import clean_abbreviations, dedupe_title_chunks, combine_listitem_chunks, remove_header_footer_repeated, chunk_by_header
37
+ from metadata_adder import UnstructuredPDFPostProcessor
38
+
39
+ #####################################################
40
+ # Get Vector Store
41
+ @st.cache_resource
42
+ def get_vector_store() -> QdrantVectorStore:
43
+ qdr_client = qdrant_client.QdrantClient(
44
+ location=":memory:"
45
+ )
46
+ qdr_aclient = qdrant_client.AsyncQdrantClient(
47
+ location=":memory:"
48
+ )
49
+ return QdrantVectorStore(client=qdr_client, aclient=qdr_aclient, collection_name='pdf', prefer_grpc=True)
50
+
51
+
52
+ # Get Document Store from List of Documents
53
+ # @st.cache_resource # can't hash a list.
54
+ def get_docstore(documents: List) -> BaseDocumentStore:
55
+ """Get the document store from a list of documents."""
56
+ docstore = SimpleDocumentStore()
57
+ docstore.add_documents(documents)
58
+ return docstore
59
+
60
+
61
+ # Get storage context and
62
+ # @st.cache_resource # can't cache the pdf_reader or vector_store
63
+ # def pdf_to_storage(
64
+ # pdf_file_path: Optional[str],
65
+ # pdf_file: Optional[IO[bytes]],
66
+ # _pdf_reader: UnstructuredPDFReader,
67
+ # _embed_model: BaseEmbedding,
68
+ # _node_parser: Optional[NodeParser] = None,
69
+ # _pdf_postprocessor: Optional[UnstructuredPDFPostProcessor] = None,
70
+ # _vector_store: Optional[QdrantVectorStore]=None,
71
+ # ) -> Tuple[StorageContext, VectorStoreIndex]:
72
+ # """Read in PDF and save to storage."""
73
+
74
+ # # Read the PDF with the PDF reader
75
+ # pdf_chunks = _pdf_reader.load_data(pdf_file_path=pdf_file_path, pdf_file=pdf_file)
76
+
77
+ # # Clean the PDF chunks
78
+ # # Insert any cleaners here.
79
+
80
+ # # TODO: Cleaners to remove repeated header/footer text, overlapping elements, ...
81
+ # pdf_chunks = clean_abbreviations(pdf_chunks)
82
+ # pdf_chunks = dedupe_title_chunks(pdf_chunks)
83
+ # pdf_chunks = combine_listitem_chunks(pdf_chunks)
84
+ # pdf_chunks = remove_header_footer_repeated(pdf_chunks)
85
+ # empty_cache()
86
+ # gc.collect()
87
+
88
+ # # Postprocess the PDF nodes.
89
+ # if (_node_parser is None):
90
+ # _node_parser = Settings.node_parser
91
+
92
+ # # Combine by semantic headers
93
+ # pdf_chunks = chunk_by_header(pdf_chunks, 1000)
94
+ # pdf_chunks = _node_parser.get_nodes_from_documents(pdf_chunks)
95
+
96
+ # if (_pdf_postprocessor is not None):
97
+ # pdf_chunks = _pdf_postprocessor(pdf_chunks)
98
+
99
+ # # Add embeddings
100
+ # pdf_chunks = _embed_model(pdf_chunks)
101
+
102
+ # # Create Document Store
103
+ # docstore = get_docstore(documents=pdf_chunks)
104
+
105
+ # # Create Vector Store if not provided
106
+ # if (_vector_store is None):
107
+ # _vector_store = get_vector_store()
108
+
109
+ # ## TODO: Handle images in StorageContext.
110
+
111
+ # # Save into Storage
112
+ # storage_context = StorageContext.from_defaults(
113
+ # docstore=docstore,
114
+ # vector_store=_vector_store
115
+ # )
116
+ # vector_store_index = VectorStoreIndex(
117
+ # pdf_chunks, storage_context=storage_context
118
+ # )
119
+
120
+ # return (storage_context, vector_store_index)
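The commented-out pipeline above shows the intended flow; below is a minimal sketch of combining the two active helpers, assuming nodes is a list of already-parsed (and embedded) chunks produced by the reader modules.

from llama_index.core import StorageContext, VectorStoreIndex
from storage import get_vector_store, get_docstore

vector_store = get_vector_store()
docstore = get_docstore(documents=nodes)  # nodes: already-parsed and embedded chunks (assumed)
storage_context = StorageContext.from_defaults(docstore=docstore, vector_store=vector_store)
vector_store_index = VectorStoreIndex(nodes, storage_context=storage_context)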
summary.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #####################################################
2
+ ### DOCUMENT PROCESSOR [Summarizer]
3
+ #####################################################
4
+ ### Jonathan Wang
5
+
6
+ # ABOUT:
7
+ # This creates an app to chat with PDFs.
8
+
9
+ # This is the Summarizer
10
+ # Which creates summaries based on documents.
11
+ #####################################################
12
+ ### TODO Board:
13
+ # Summary Index for document?
14
+
15
+ # https://docs.llamaindex.ai/en/stable/examples/response_synthesizers/tree_summarize/
16
+ # https://sourajit16-02-93.medium.com/text-summarization-unleashed-novice-to-maestro-with-llms-and-instant-code-solutions-8d26747689c4
17
+
18
+ #####################################################
19
+ ### PROGRAM SETTINGS
20
+
21
+
22
+ #####################################################
23
+ ### PROGRAM IMPORTS
24
+ import logging
25
+
26
+ from typing import Optional, Sequence, Any, Callable, cast
27
+ from llama_index.core.bridge.pydantic import Field, PrivateAttr
28
+
29
+ from llama_index.core.settings import Settings
30
+ from llama_index.core.base.llms.base import BaseLLM
31
+ from llama_index.core.multi_modal_llms import MultiModalLLM
32
+ from llama_index.core.schema import BaseNode, TextNode, ImageDocument
33
+ from llama_index.core.callbacks.base import CallbackManager
34
+
35
+ from llama_index.core.response_synthesizers import TreeSummarize
36
+
37
+ # Own Modules
38
+ from metadata_adder import ModelMetadataAdder
39
+
40
+ #####################################################
41
+ ### CONSTANTS
42
+ logger = logging.getLogger(__name__)
43
+
44
+ DEFAULT_SUMMARY_TEMPLATE = """You are an expert summarizer of information. You are given some information from a document. Summarize the information, and then provide the key information that can be drawn from it. The information is below:
45
+ {context_str}
46
+ """
47
+
48
+ DEFAULT_ONELINE_SUMMARY_TEMPLATE = """You are an expert summarizer of information. You are given a summary of a document. In no more than three sentences, describe the subject of the document, the main ideas of the document, and what types of questions can be answered from it."""
49
+
50
+ DEFAULT_TREE_SUMMARY_TEMPLATE = """You are an expert summarizer of information. You are given some text from a document.
51
+ Please provide a comprehensive summary of the text.
52
+ Include the main subject of the text, the key points or topics, and the most important conclusions if there are any.
53
+ The summary should be detailed yet concise."""
54
+
55
+ DEFAULT_TABLE_SUMMARY_TEMPLATE = """You are an expert summarizer of tables. You are given a table or part of a table in HTML format. The table is below:
56
+ {context_str}
57
+ ----------------
58
+ Summarize the table, and then provide the key insights that can be drawn directly from the table. If this is not actually an HTML table or part of an HTML table, please do not respond.
59
+ """
60
+
61
+ DEFAULT_IMAGE_SUMMARY_TEMPLATE = """You are an expert image summarizer. You are given an image. Summarize the image, and then provide the key insights that can be drawn directly from the image, if there are any.
62
+ """
63
+
64
+ #####################################################
65
+ ### SCRIPT
66
+
67
+ class TextSummaryMetadataAdder(ModelMetadataAdder):
68
+ """Adds metadata to nodes based on a language model."""
69
+
70
+ _llm: BaseLLM = PrivateAttr()
71
+
72
+ def __init__(
73
+ self,
74
+ metadata_name: str,
75
+ llm: Optional[BaseLLM] = None,
76
+ prompt_template: Optional[str] = DEFAULT_SUMMARY_TEMPLATE,
77
+ **kwargs: Any
78
+ ) -> None:
79
+ """Init params."""
80
+ llm = llm or Settings.llm
81
+ prompt_template = prompt_template if prompt_template is not None else DEFAULT_SUMMARY_TEMPLATE
82
+ super().__init__(metadata_name=metadata_name, prompt_template=prompt_template, **kwargs)
+ self._llm = llm  # used by get_node_metadata
83
+
84
+ @classmethod
85
+ def class_name(cls) -> str:
86
+ return "TextSummaryMetadataAdder"
87
+
88
+ def get_node_metadata(self, node: BaseNode) -> Optional[str]:
89
+ if (getattr(node, 'text', None) is None):
90
+ return None
91
+
92
+ response = self._llm.complete(prompt=self.prompt_template.format(context_str=node.text))
93
+ return response.text
94
+
95
+
96
+ class TableSummaryMetadataAdder(ModelMetadataAdder):
97
+ """Adds table summary metadata to a document.
98
+
99
+ Args:
100
+ metadata_name: The name of the metadata to add to the document. Defaults to 'table_summary'.
101
+ llm: The LLM to use to generate the table summary. Defaults to Settings llm.
102
+ prompt_template: The prompt template to use to generate the table summary. Defaults to DEFAULT_TABLE_SUMMARY_TEMPLATE.
103
+ """
104
+ _llm: BaseLLM = PrivateAttr()
105
+
106
+ def __init__(
107
+ self,
108
+ metadata_name: str = "table_summary", ## TODO: This is a bad pattern, string should not be hardcoded like this
109
+ llm: Optional[BaseLLM] = None,
110
+ prompt_template: Optional[str] = DEFAULT_TABLE_SUMMARY_TEMPLATE,
111
+ # num_workers: int = 1,
112
+ **kwargs: Any,
113
+ ) -> None:
114
+ """Init params."""
115
+ llm = llm or Settings.llm
116
+ prompt_template = prompt_template or DEFAULT_TABLE_SUMMARY_TEMPLATE
117
+ super().__init__(metadata_name=metadata_name, prompt_template=prompt_template, **kwargs)
118
+ self._llm = llm
119
+
120
+ @classmethod
121
+ def class_name(cls) -> str:
122
+ return "TableSummaryMetadataAdder"
123
+
124
+ def get_node_metadata(self, node: BaseNode) -> Optional[str]:
125
+ """Given a node, get the metadata for the node using the language model."""
126
+ ## NOTE: Our PDF Reader parser distinguishes between TextNode and TableNode using the 'orignal_table_text' attribute.
127
+ ## BUG (future): `orignal_table_text` should not be hardcoded.
128
+ if (not isinstance(node, TextNode)):
129
+ return None
130
+ if (node.metadata.get('orignal_table_text') is None):
131
+ return None
132
+ if (getattr(node, 'text', None) is None):
133
+ return None
134
+
135
+ response = self._llm.complete(
136
+ self.prompt_template.format(context_str=node.text)
137
+ )
138
+ return response.text
139
+
140
+
141
+ class ImageSummaryMetadataAdder(ModelMetadataAdder):
142
+ """Adds image summary metadata to a document.
143
+
144
+ Args:
145
+ metadata_name: The name of the metadata to add to the document. Defaults to 'image_summary'.
146
+ """
147
+ _llm: MultiModalLLM = PrivateAttr()
148
+
149
+ def __init__(
150
+ self,
151
+ llm: MultiModalLLM,
152
+ prompt_template: str = DEFAULT_IMAGE_SUMMARY_TEMPLATE,
153
+ metadata_name: str = 'image_summary',
154
+ **kwargs: Any,
155
+ ) -> None:
156
+ """Init params."""
157
+ super().__init__(metadata_name=metadata_name, prompt_template=prompt_template, **kwargs)
158
+ self._llm = llm
159
+
160
+ @classmethod
161
+ def class_name(cls) -> str:
162
+ return "ImageSummaryMetadataAdder"
163
+
164
+ def _get_image_node_metadata(self, node: BaseNode) -> Optional[str]:
165
+ """Handles getting images from image nodes.
166
+
167
+ Args:
168
+ node (BaseNode): The image node to get the image summary for. NOTE: This can technically be any type of node so long as it has an image stored.
169
+
170
+ Returns:
171
+ Optional[str]: The image summary if it exists. If not, return None.
172
+ """
173
+ if (
174
+ ((getattr(node, 'image', None) is None) and (getattr(node, 'image_path', None) is None))
175
+ or (not callable(getattr(node, "resolve_image", None))) # method used to convert node to PILImage for model.
176
+ ):
177
+ # Not a valid image node with image attributes and image conversion.
178
+ return None
179
+
180
+ # Check whether the image is of text or not
181
+ ### TODO: Replace this with a text-overlap thing.
182
+ image = node.resolve_image() # type: ignore | we check for this above.
183
+ im_width, im_height = image.size
184
+ if (im_width < 70): # TODO: this really should be based on the average text width / whether this is overlapping text.
185
+ return None
186
+
187
+ ## NOTE: We're assuming that the llm complete function has a parameter `images` to send image node(s) as input.
188
+ ## This is NOT necessarily true if the end user decides to create their own implementation of a MultiModalLLM.
189
+ response = self._llm.complete(
190
+ prompt=self.prompt_template,
191
+ image_documents=[
192
+ cast(ImageDocument, node) # NOTE: This is a hack. Technically, node should be an ImageNode, a parent of ImageDocument; but I don't think we'll be using the Document features so this should be okay.
193
+ ],
194
+ )
195
+ return response.text
196
+
197
+ def _get_composite_node_metadata(self, node: BaseNode) -> Optional[str]:
198
+ """Handles getting images from composite nodes (i.e., where an image is stored as a original node inside a composite node).
199
+
200
+ Args:
201
+ node (TextNode): The composite node to get the image summary for.
202
+
203
+ Returns:
204
+ Optional[str]: The image summary if it exists. If not, return None.
205
+ """
206
+ if ('orig_nodes' not in node.metadata):
207
+ return None # no image nodes in the composite node.
208
+
209
+ output = ""
210
+ for orig_node in node.metadata['orig_nodes']:
211
+ output += str(self._get_image_node_metadata(orig_node) or "")
212
+
213
+ if (output == ""):
214
+ return None
215
+ return output
216
+
217
+ def get_node_metadata(self, node: BaseNode) -> Optional[str]:
218
+ """Get the image summary for a node (or subnodes)."""
219
+
220
+ if (node.metadata['type'].startswith('Composite')):
221
+ return self._get_composite_node_metadata(node)
222
+ else:
223
+ return self._get_image_node_metadata(node)
224
+
225
+
226
+ def get_tree_summarizer(
227
+ llm: Optional[BaseLLM] = None,
228
+ callback_manager: Optional[CallbackManager] = None,
229
+ ) -> TreeSummarize:
230
+ llm = llm or Settings.llm
231
+ tree_summarizer = TreeSummarize(llm=llm, callback_manager=callback_manager)
232
+ return (tree_summarizer)
233
+
234
+
235
+ def get_tree_summary(tree_summarizer: TreeSummarize, text_chunks: Sequence[BaseNode]) -> str:
236
+ """Summarize the text nodes using a tree summarizer.
237
+
238
+ Args:
239
+ tree_summarizer (TreeSummarize): The tree summarizer to use.
240
+ text_chunks (Sequence[BaseNode]): The text nodes to summarize.
241
+
242
+ Returns:
243
+ str: The summarized text.
244
+ """
245
+ # Synchronous tree summarization; the async variant (aget_response) would need to be awaited.
+ response = tree_summarizer.get_response(query_str=DEFAULT_TREE_SUMMARY_TEMPLATE, text_chunks=[getattr(chunk, 'text') for chunk in text_chunks if hasattr(chunk, 'text')])
246
+ return str(response)
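A brief usage sketch of the summarizer helpers, assuming pdf_chunks is the node list produced by the PDF reader and that Settings.llm has already been configured.

from summary import get_tree_summarizer, get_tree_summary

tree_summarizer = get_tree_summarizer()  # falls back to Settings.llm when no LLM is passed
document_summary = get_tree_summary(tree_summarizer, text_chunks=pdf_chunks)  # pdf_chunks assumed from the reader
print(document_summary)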